from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from TTS.api import TTS
from tqdm.auto import tqdm
# Pick the torch device for inference: CUDA when torch can see a GPU,
# CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the multilingual XTTS v2 voice-cloning model once and move it to the
# chosen device; it is reused for every speaker/sentence pair below.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
TSV_DIR = "/home/joregan/humogen/genea2023/tsv"
REF_DIR = "/home/joregan/humogen/reference_speakers"


def load_sentences(tsv_dir):
    """Collect the transcript text of every ``*.tsv`` file in *tsv_dir*.

    Each line is stripped and split on tabs. Lines containing no tab are
    skipped, and only lines with exactly three fields contribute; their
    third field is taken as the text. A file's texts are joined with single
    spaces (a file with no qualifying lines maps to ``""``).

    Args:
        tsv_dir: directory (``str`` or ``Path``) holding the ``.tsv`` files.

    Returns:
        dict mapping each file's stem to its joined text. Files are visited
        in sorted path order, so the dict's iteration order is reproducible.
    """
    result = {}
    # Path.glob yields entries in filesystem order, which is arbitrary;
    # sorting keeps downstream index-based file naming stable across runs.
    for tsv_path in sorted(Path(tsv_dir).glob("*.tsv")):
        texts = []
        with open(tsv_path, encoding="utf-8") as f:
            for line in f:
                if "\t" not in line:
                    continue
                fields = line.strip().split("\t")
                if len(fields) == 3:
                    texts.append(fields[2])
        result[tsv_path.stem] = " ".join(texts)
    return result


sentences = load_sentences(TSV_DIR)
# Written sample rate. NOTE(review): presumably XTTS v2's native output
# rate — confirm against tts.synthesizer.output_sample_rate.
SAMPLE_RATE = 24_000
OUTPUT_ROOT = Path("/tmp/outputs")

# For every reference speaker, synthesise every transcript in that voice and
# save it as OUTPUT_ROOT/<speaker number>/<sentence number>.wav (both 1-based).
# Speakers and sentences are iterated in sorted order because glob/dict order
# is filesystem-dependent; this keeps the numbering stable across runs.
speaker_wavs = sorted(Path(REF_DIR).glob("*.wav"))
for spk_no, speaker in enumerate(tqdm(speaker_wavs), start=1):
    folder = OUTPUT_ROOT / str(spk_no)
    # parents=True so the first run doesn't fail when OUTPUT_ROOT is missing.
    folder.mkdir(parents=True, exist_ok=True)
    # tqdm.write instead of print so the progress bars aren't garbled.
    tqdm.write(str(speaker))
    ordered = sorted(sentences.items())
    for i, (_, text) in enumerate(tqdm(ordered, leave=False), start=1):
        wav = tts.tts(
            text=text,
            speaker_wav=str(speaker),
            language="en",
            split_sentences=False,
        )
        sf.write(folder / f"{i}.wav", np.asarray(wav), SAMPLE_RATE, "PCM_24")