WER on LJSpeech, 5-30min subsets, take 2
WER check using the LJSpeech test set
Original on Kaggle
%%capture
!pip install transformers datasets jiwer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from datasets import Dataset
import soundfile as sf
import torch
from jiwer import wer
# Collect the test-split utterance IDs, skipping the header row
test_ids = []
with open("../input/create-ljspeech-splits/test.tsv") as tsvf:
    for line in tsvf.readlines()[1:]:
        parts = line.split("\t")
        test_ids.append(parts[0].replace(".wav", ""))
# Load the reference transcripts, uppercased to match the model's character vocabulary
transcripts = {}
with open("../input/ljspeech-for-asr/transcripts.tsv") as tsf:
    for line in tsf.readlines():
        parts = line.strip().split("\t")
        transcripts[parts[0]] = parts[1].upper()
# Build a Dataset of (wav path, reference text) pairs for the test split
paths = []
text = []
for utt_id in test_ids:
    paths.append(f"/kaggle/input/ljspeech-for-asr/wav16/{utt_id}.wav")
    text.append(transcripts[utt_id])
dataset = Dataset.from_dict({"file": paths, "text": text})
# Read each wav file into a "speech" column
def map_to_array(batch):
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch

dataset = dataset.map(map_to_array)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
# Checkpoint fine-tuned on the 25-minute subset
model = Wav2Vec2ForCTC.from_pretrained("jimregan/wav2vec-ljspeech-splits", revision="25mins-2")
# Greedy CTC decoding: run the model on the audio and argmax-decode the logits
def map_to_pred(batch):
    input_values = tokenizer(batch["speech"], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

result = dataset.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
print("WER:", wer(result["text"], result["transcription"]))
# Repeat the evaluation with the checkpoint fine-tuned on the 30-minute subset;
# map_to_pred reads the module-level `model`, so it does not need to be redefined
model = Wav2Vec2ForCTC.from_pretrained("jimregan/wav2vec-ljspeech-splits", revision="30mins-2")
result = dataset.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
print("WER:", wer(result["text"], result["transcription"]))