%%capture
!pip install librosa webrtcvad

The VAD wrapper is taken from PyTorch Speaker Verification, which in turn is based on py-webrtcvad.

# VAD wrapper is taken from PyTorch Speaker Verification:
# https://github.com/HarryVolek/PyTorch_Speaker_Verification
# Copyright (c) 2019, HarryVolek
# License: BSD-3-Clause
# based on https://github.com/wiseman/py-webrtcvad/blob/master/example.py
# Copyright (c) 2016 John Wiseman
# License: MIT
import collections
import contextlib
import numpy as np
import sys
import librosa
import wave

import webrtcvad

# from hparam import hparam as hp
sr = 16000  # sample rate used throughout

def read_wave(path, sr):
    """Reads a .wav file.
    Takes the path, and returns (PCM audio data, sample rate).
    Assumes sample width == 2
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
    data, _ = librosa.load(path, sr=sr)
    assert len(data.shape) == 1
    assert sr in (8000, 16000, 32000, 48000)
    return data, pcm_data
    
class Frame(object):
    """Represents a "frame" of audio data."""
    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    """Generates audio frames from PCM audio data.
    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate.
    Yields Frames of the requested duration.
    """
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Filters out non-voiced audio frames.
    Given a webrtcvad.Vad and a source of audio frames, yields only
    the voiced audio.
    Uses a padded, sliding window algorithm over the audio frames.
    When more than 90% of the frames in the window are voiced (as
    reported by the VAD), the collector triggers and begins yielding
    audio frames. Then the collector waits until 90% of the frames in
    the window are unvoiced to detrigger.
    The window is padded at the front and back to provide a small
    amount of silence or the beginnings/endings of speech around the
    voiced frames.
    Arguments:
    sample_rate - The audio sample rate, in Hz.
    frame_duration_ms - The frame duration in milliseconds.
    padding_duration_ms - The amount to pad the window, in milliseconds.
    vad - An instance of webrtcvad.Vad.
    frames - a source of audio frames (sequence or generator).
    Returns: A generator that yields PCM audio data.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    # We use a deque for our sliding window/ring buffer.
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
    # NOTTRIGGERED state.
    triggered = False

    voiced_frames = []
    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)

        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            # If we're NOTTRIGGERED and more than 90% of the frames in
            # the ring buffer are voiced frames, then enter the
            # TRIGGERED state.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True
                start = ring_buffer[0][0].timestamp
                # We want to yield all the audio we see from now until
                # we are NOTTRIGGERED, but we have to start with the
                # audio that's already in the ring buffer.
                for f, s in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            # We're in the TRIGGERED state, so collect the audio data
            # and add it to the ring buffer.
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            # If more than 90% of the frames in the ring buffer are
            # unvoiced, then enter NOTTRIGGERED and yield whatever
            # audio we've collected.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                triggered = False
                yield (start, frame.timestamp + frame.duration)
                ring_buffer.clear()
                voiced_frames = []
    # If we have any leftover voiced audio when we run out of input,
    # yield it.
    if voiced_frames:
        yield (start, frame.timestamp + frame.duration)


def VAD_chunk(aggressiveness, path):
    audio, byte_audio = read_wave(path, sr)
    vad = webrtcvad.Vad(int(aggressiveness))
    frames = frame_generator(20, byte_audio, sr)
    frames = list(frames)
    times = vad_collector(sr, 20, 200, vad, frames)
    speech_times = []
    speech_segs = []
    for i, time in enumerate(times):
        start = np.round(time[0],decimals=2)
        end = np.round(time[1],decimals=2)
        j = start
        while j + .4 < end:
            end_j = np.round(j+.4,decimals=2)
            speech_times.append((j, end_j))
            speech_segs.append(audio[int(j*sr):int(end_j*sr)])
            j = end_j
        else:
            # while/else: runs when the loop condition fails, i.e. when the
            # remaining tail is shorter than 0.4 s
            speech_times.append((j, end))
            speech_segs.append(audio[int(j*sr):int(end*sr)])
    return speech_times, speech_segs
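
For context (an addition of mine, not part of the wrapper), the underlying py-webrtcvad API is tiny: Vad.is_speech() classifies a single 10, 20 or 30 ms frame of 16-bit mono PCM at 8, 16, 32 or 48 kHz.

# Illustrative only: classify one 20 ms frame of pure silence.
demo_vad = webrtcvad.Vad(3)                 # aggressiveness 0-3
silence = b'\x00\x00' * int(16000 * 0.02)   # 320 samples of 16-bit silence
print(demo_vad.is_speech(silence, 16000))   # should be False for silence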

Running

I'm going to use a video from YouTube as my input, so first I need to install youtube-dl.

%%capture
!pip install youtube-dl

I've selected this video because it's a speech by the President of Ireland (and so copyright-free as a matter of public record), it has subtitles (in Irish, though listed as English), and the subtitles are quite faithful to what was spoken.

%%capture
!youtube-dl --all-subs -o '%(id)s' VRg-a0qSGa8

The audio needs to be a 16 kHz mono WAV, so I'm converting it with ffmpeg.

%%capture
!ffmpeg -i VRg-a0qSGa8.mkv -acodec pcm_s16le -ac 1 -ar 16000 VRg-a0qSGa8.wav

Next, I'm using the VAD_chunk() function to get the start and end times, and the audio segments, of each part of the video that contains speech.

times, segs = VAD_chunk(3, 'VRg-a0qSGa8.wav')
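
To get a feel for what VAD_chunk() returns (this check is mine, not part of the original notebook): each entry in times is a (start, end) pair in seconds, and segs holds the matching slice of the 16 kHz float audio.

# Illustrative inspection of the VAD output.
print(len(times), 'speech segments')
print(times[:3])
print([len(s) / sr for s in segs[:3]])  # durations in seconds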

The wav2vec2 models generally perform badly on short input, so vad_concat() concatenates continuous segments, along with their times (which DSAlign needs later).

# Based on code from PyTorch Speaker Verification:
# https://github.com/HarryVolek/PyTorch_Speaker_Verification
# Copyright (c) 2019, HarryVolek
# Additions Copyright (c) 2021, Jim O'Regan
# License: MIT
import numpy as np

# wav2vec2's max duration is 40 seconds, using 39 by default
# to be a little safer
def vad_concat(times, segs, max_duration=39.0):
    """
    Concatenate continuous times and their segments, where the end time
    of a segment is the same as the start time of the next
        Parameters:
            times: list of tuple (start, end)
            segs: list of segments (audio frames)
            max_duration: maximum duration of the resulting concatenated
                segments; the kernel size of wav2vec2 is 40 seconds, so
                the default max_duration is 39, to ensure the resulting
                list of segments will fit
        Returns:
            concat_times: list of tuple (start, end)
            concat_segs: list of segments (audio frames)
    """
    absolute_maximum=40.0
    if max_duration > absolute_maximum:
        raise Exception('`max_duration` {:.2f} larger than kernel size (40 seconds)'.format(max_duration))
    # we take 0.0 to mean "don't concatenate"
    do_concat = (max_duration != 0.0)
    concat_seg = []
    concat_times = []
    seg_concat = segs[0]
    time_concat = times[0]
    for i in range(0, len(times)-1):
        can_concat = (times[i+1][1] - time_concat[0]) < max_duration
        if time_concat[1] == times[i+1][0] and do_concat and can_concat:
            seg_concat = np.concatenate((seg_concat, segs[i+1]))
            time_concat = (time_concat[0], times[i+1][1])
        else:
            concat_seg.append(seg_concat)
            seg_concat = segs[i+1]
            concat_times.append(time_concat)
            time_concat = times[i+1]
    else:
        # for/else with no break: this always runs after the loop,
        # flushing the final segment
        concat_seg.append(seg_concat)
        concat_times.append(time_concat)
    return concat_times, concat_seg

ntimes, nsegs = vad_concat(times, segs)
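
As a small sanity check of my own (not from the original code), here's vad_concat() on synthetic input: two contiguous 0.4-second chunks are merged, while a chunk that starts after a gap stays separate.

# Toy example (illustrative): contiguous chunks merge, gaps are kept.
toy_times = [(0.0, 0.4), (0.4, 0.8), (1.2, 1.6)]
toy_segs = [np.zeros(int(0.4 * sr)) for _ in toy_times]
toy_ctimes, toy_csegs = vad_concat(toy_times, toy_segs)
print(toy_ctimes)                        # [(0.0, 0.8), (1.2, 1.6)]
print([len(s) / sr for s in toy_csegs])  # [0.8, 0.4]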

Next, I'm putting the data into a dict that Hugging Face's datasets library can read:

starts = [s[0] for s in ntimes]
ends = [s[1] for s in ntimes]
dset = {'start': starts,
        'end': ends,
        'speech': nsegs}
%%capture
!pip install datasets
from datasets import Dataset
dataset = Dataset.from_dict(dset)
dataset
Dataset({
    features: ['start', 'end', 'speech'],
    num_rows: 137
})

Now, the data is ready to plug into my wav2vec2 model.

%%capture
!pip install -q transformers
%%capture
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# load model and tokenizer
processor = Wav2Vec2Processor.from_pretrained("jimregan/wav2vec2-large-xlsr-irish-basic")
model = Wav2Vec2ForCTC.from_pretrained("jimregan/wav2vec2-large-xlsr-irish-basic")
model.to("cuda")
Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.

# Note: this helper expects 'path' and 'sentence' columns and isn't used
# below, since the VAD segments are already float arrays in the dataset's
# 'speech' column.
def speech_file_to_array_fn(batch):
    import torchaudio
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = speech_array[0].numpy()
    batch["sampling_rate"] = sampling_rate
    batch["target_text"] = batch["sentence"]
    return batch
def evaluate(batch):
  import torch
  inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

  with torch.no_grad():
    logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits

  pred_ids = torch.argmax(logits, dim=-1)
  batch["pred_strings"] = processor.batch_decode(pred_ids)
  return batch

result = dataset.map(evaluate, batched=True, batch_size=8)
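
A quick look at one prediction (illustrative; not in the original notebook). The next cell then writes these out as a JSON list of {start, end, transcript} fragments, which is the transcription log (.tlog) format DSAlign reads.

# Illustrative peek at the first fragment's prediction.
print(result[0]['start'], result[0]['end'])
print(result[0]['pred_strings'])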

speechless = result.remove_columns(['speech'])
d=speechless.to_dict()
tlog = list()
for i in range(0, len(d['end']) - 1):  # note: skips the last fragment
  out = dict()
  out['start'] = d['start'][i]
  out['end'] = d['end'][i]
  out['transcript'] = d['pred_strings'][i]
  tlog.append(out)
import json
with open('/content/VRg-a0qSGa8.tlog', 'w') as outfile:
    json.dump(tlog, outfile)

Next, I'm extracting the text content from the VTT subtitle file.

!pip install webvtt-py
def get_vtt_text(filename):
  import webvtt
  out = list()
  for sub in webvtt.read(filename):
    out.append(sub.text)
  return ' '.join(out)
text = get_vtt_text('/content/VRg-a0qSGa8.en.vtt')

I can do some normalisation now:

text = text.replace('1901', 'naoi déag is a haon')
text = text.replace('2021', 'fiche is fiche is a haon')
text = text.replace('Covid-19', 'covid a naoi déag')
text = text.replace('fiche fiche haon', 'fiche is fiche is a haon')

I want sentences, so I'm going to use mosestokenizer to split the text (there aren't any problematic abbreviations in this video, so the English splitter works fine; YMMV).

%%capture
!pip install mosestokenizer

The actual Moses tokeniser has sentence-splitting support for Irish, but the Python version was forked before that was added; we don't actually need any Irish-specific support here, so we can just use English.

from mosestokenizer import MosesSentenceSplitter
with MosesSentenceSplitter('en') as splitsents:
  sents = splitsents([text])
with open('/content/VRg-a0qSGa8.txt', 'w') as outfile:
  outfile.writelines(['\n'.join(sents)])

DSAlign requires an alphabet file (one character per line), so I'll create that first.

alpha="aábcdeéfghiíjklmnoópqrstuúvwxyz'-"
alpha_chars = [char for char in alpha]
with open('/content/ga.alphabet', 'w') as outfile:
  outfile.writelines(['\n'.join(alpha_chars)])
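
As a quick check of my own (not in the original), it can be useful to see which characters in the script fall outside that alphabet; digits and punctuation will show up, and whether they matter depends on how the aligner preprocesses its input.

# Illustrative sanity check: script characters not in the alphabet
# (after lowercasing, ignoring whitespace).
leftover = set(text.lower()) - set(alpha) - set(' \n\t')
print(sorted(leftover))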

Now, to install DSAlign and its dependencies:

%%capture
!git clone https://github.com/mozilla/DSAlign
%%capture
!apt-get install sox
%%capture
import os
os.chdir('DSAlign')
!pip install -r requirements.txt

Now, I'm ready to align:

!bin/align.sh --force --tlog /content/VRg-a0qSGa8.tlog --script /content/VRg-a0qSGa8.txt --aligned /content/VRg-a0qSGa8.aligned --text-meaningful-newlines --alphabet /content/ga.alphabet
bin/align.sh: line 3: /content/DSAlign/venv/bin/activate: No such file or directory
INFO:root:Aligning
 1 of 1 : 100.00% (elapsed: 00:00:04, speed: 0.25 it/s, ETA: 00:00:00)
INFO:root:Aligned 24 fragments
INFO:root:Dropped 112 fragments 466.67%:

24 out of 136 fragments isn't great, but it's quite good considering the WER of the model (43.7%); the next step would be to add the aligned data to the training set, retrain, and repeat.
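
For that next step, here's a rough sketch of how the aligned output might be turned into training clips. This isn't from the original post, and the field names ('start', 'end', 'aligned') are assumptions about the DSAlign output; check them against the .aligned file before relying on this.

# Hedged sketch: cut the aligned fragments out of the full wav and write
# a clip/transcript list. Field names are assumptions; inspect the
# .aligned JSON from your own DSAlign run first.
import json
import librosa
import soundfile as sf

with open('/content/VRg-a0qSGa8.aligned') as f:
    aligned = json.load(f)

audio, _ = librosa.load('/content/VRg-a0qSGa8.wav', sr=16000)

with open('/content/aligned_clips.tsv', 'w') as tsv:
    for i, frag in enumerate(aligned):
        start, end = frag['start'], frag['end']   # assumed to be seconds, as in the tlog
        clip = audio[int(start * 16000):int(end * 16000)]
        clip_path = '/content/clip_{:04d}.wav'.format(i)
        sf.write(clip_path, clip, 16000)
        tsv.write('{}\t{}\n'.format(clip_path, frag['aligned']))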