import soundfile as sf
import wave
def smp_headers(filename: str, header_size: int = 1024) -> dict:
    """Parse the ASCII ``key=value`` header at the start of an ``.smp`` file.

    The first ``header_size`` bytes are read, trailing NUL padding is removed,
    and the text is split on CRLF. Parsing stops at a line that is exactly
    ``=`` (presumably the end-of-header marker — confirm against the SMP
    format description). Blank lines and lines without ``=`` are ignored.
    Only the first ``=`` on a line separates key from value, so values may
    themselves contain ``=``.

    Args:
        filename: Path to the ``.smp`` file.
        header_size: Number of bytes reserved for the header (default 1024).

    Returns:
        Dict mapping header keys to their string values.

    Raises:
        UnicodeDecodeError: if the header bytes are not ASCII.
    """
    with open(filename, "rb") as f:
        raw = f.read(header_size)
    text = raw.rstrip(b"\x00").decode("ascii")

    headers = {}
    for line in text.split("\r\n"):
        if line == "=":
            break
        if "=" not in line:
            # Blank lines (e.g. after the final CRLF) and stray non key=value
            # lines carry no header data.
            continue
        key, value = line.split("=", 1)
        headers[key] = value
    return headers


def smp_read_sf(filename: str, samplerate: int = 16000, header_size: int = 1024):
    """Read the 16-bit PCM samples of an ``.smp`` file using soundfile.

    The file is read as headerless ``RAW``/``PCM_16`` data, starting just
    past the ``header_size``-byte header. Channel count comes from the
    header's ``nchans`` field. Byte order comes from ``msb``: ``last`` selects
    little-endian, anything else big-endian.

    Args:
        filename: Path to the ``.smp`` file.
        samplerate: Sample rate to report (RAW data carries none; default 16000).
        header_size: Size of the header in bytes (default 1024).

    Returns:
        ``(data, sr)`` as returned by ``soundfile.read`` with ``dtype="int16"``.

    Raises:
        KeyError: if the header lacks ``msb`` or ``nchans``.
    """
    headers = smp_headers(filename)
    endian = "LITTLE" if headers["msb"] == "last" else "BIG"
    channels = int(headers["nchans"])
    # soundfile's ``start`` is a frame index, not a byte offset: one PCM_16
    # frame is 2 bytes per channel. A fixed 512 is only correct for mono.
    # NOTE(review): if header_size is not a multiple of 2 * channels the
    # header is not frame-aligned and this rounds down — confirm it never occurs.
    start = header_size // (2 * channels)
    data, sr = sf.read(filename, channels=channels,
                       samplerate=samplerate, endian=endian, start=start,
                       dtype="int16", format="RAW", subtype="PCM_16")
    return (data, sr)


def write_wav(filename, arr, framerate=16000, nchannels=1, sampwidth=2):
    """Write raw PCM sample data to a WAV file.

    Args:
        filename: Output path (``str`` or path-like), or a writable binary
            file object as accepted by ``wave.open``.
        arr: Bytes-like sample data (e.g. a contiguous numpy int16 array).
            Samples are written as-is, so they should be ``sampwidth`` bytes
            each, interleaved when ``nchannels > 1``, in the byte order WAV
            expects (little-endian) — assumed, not checked here.
        framerate: Sample rate in Hz (default 16000).
        nchannels: Number of channels (default 1).
        sampwidth: Bytes per sample (default 2, i.e. 16-bit).
    """
    import os

    # wave.open treats anything that is not a str as a file object, so
    # convert path-likes (e.g. pathlib.Path) to their string form first.
    if isinstance(filename, os.PathLike):
        filename = os.fspath(filename)
    with wave.open(filename, "wb") as f:
        f.setnchannels(nchannels)
        f.setsampwidth(sampwidth)
        f.setframerate(framerate)
        f.writeframes(arr)
from pathlib import Path
WAXHOLM = "/Users/joregan/Playing/waxholm"
OUTPUT = "/Users/joregan/Playing/waxholm_fairseq"
SCENES_PATH = Path(WAXHOLM) / "scenes_formatted"
OUTPUT_PATH = Path(OUTPUT)
# exist_ok avoids the check-then-create race; parents creates missing
# intermediate directories.
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)


def _read_file_list(path):
    """Return the whitespace-stripped, non-blank lines of the text file *path*."""
    with open(path) as listf:
        return [line.strip() for line in listf if line.strip()]


TRAIN_FILES = _read_file_list(Path(WAXHOLM) / "alloktrainfiles")
TEST_FILES = _read_file_list(Path(WAXHOLM) / "testfiles")
print(len(TRAIN_FILES), len(TEST_FILES))
# Output: 1835 327
import re

def get_labels(mixfile):
    """Extract the label string from a ``.mix`` transcription file.

    The labels start on the first line beginning with ``labels:`` (matched
    case-insensitively; the prefix is dropped). They continue over the
    following lines up to, but not including, the first line that starts
    with ``FR``. The pieces are joined with single spaces. All whitespace
    runs are collapsed, and leading/trailing whitespace is removed.

    Args:
        mixfile: Path to the ``.mix`` file.

    Returns:
        The label string, or ``""`` if the file has no ``labels:`` line.
    """
    prefix = "labels:"
    parts = []
    in_labels = False
    with open(mixfile) as infile:
        for line in infile:
            if in_labels:
                if line.startswith("FR"):
                    break
                parts.append(line)
            elif line.lower().startswith(prefix):
                in_labels = True
                parts.append(line[len(prefix):])
    # split()/join collapses every whitespace run and drops the leading or
    # trailing blanks left by an empty first label line or blank continuation
    # lines (which would otherwise surface as empty "words" downstream).
    return " ".join(" ".join(parts).split())
# Spot-check label extraction on one utterance. A bare expression is only
# displayed inside a notebook; print() keeps the check visible as a script.
print(get_labels("/Users/joregan/Playing/waxholm/scenes_formatted/fp2043/fp2043.16.03.smp.mix"))
# Output: 'A:H\'A: pa p: |h J\'A:Ggv V\'ILv pap: sm p:v S\'E: pa H\'U:R 2Dd\'EM Bb\']:TtE0NG Gg\']:R 2Tt\'I STt"A:VE0#STtR`\\M p: \']: p: \']M J\'A: Kk\'AN F"O#2S`[TtA Tt\'I F"IN#H`AM .'
def segment_label(label, skip_pause=True):
    """Split one label word into a list of phone symbols.

    Rules, applied left to right:
      * ``NG``, ``E0``, ``SJ``, ``TJ``, ``kl``, ``sm`` and ``pa`` are single
        two-character phones.
      * ``p:`` (pause) is dropped unless ``skip_pause`` is False.
      * ``#`` is dropped.
      * Otherwise one character is a phone. A leading prefix (``'``, `````,
        ``"``, ``2`` or ``~``) and a trailing ``:``, ``3`` or ``4`` stay
        attached to it.
    """
    two_char_phones = ("NG", "E0", "SJ", "TJ", "kl", "sm", "pa")
    prefixes = ("'", "`", "\"", "2", "~")
    length_marks = (":", "3", "4")

    result = []
    pos = 0
    total = len(label)
    while pos < total:
        pair = label[pos:pos + 2]
        if pair in two_char_phones:
            result.append(pair)
            pos += 2
            continue
        if pair == "p:":
            if not skip_pause:
                result.append("p:")
            pos += 2
            continue
        if label[pos] == "#":
            pos += 1
            continue

        begin = pos
        if label[pos] in prefixes:
            pos += 1
        stop = pos + 1
        # Tuple membership (not a string) so an empty slice at the end of
        # the label is never treated as a length mark.
        if label[pos + 1:pos + 2] in length_marks:
            stop += 1
        result.append(label[begin:stop])
        pos = stop
    return result
# Regression checks for segment_label on fragments of a real label string:
# stress prefix + length mark, digraph phones, '#' separators, the
# backslash symbol, and a pause marker that is dropped by default.
for _fragment, _expected in (
    ("Bb\']:TtE0NG", ['B', 'b', "']:", 'T', 't', 'E0', 'NG']),
    ("STt\"A:VE0#STtR`\\M",
     ['S', 'T', 't', '"A:', 'V', 'E0', 'S', 'T', 't', 'R', '`\\', 'M']),
    ("p:v", ['v']),
):
    assert segment_label(_fragment) == _expected
def proc_label(label, stress=False):
    """Convert a label string into one fairseq letter (``.ltr``) line.

    Each whitespace-separated word becomes its phones joined by single
    spaces. Words are joined with ``" | "`` and the line ends with ``" |"``.

    Special words:
      * ``p:pa``, ``pap:``, ``p:pap:``, ``pa`` -> ``pa``
      * ``|h`` -> ``hes``
      * ``sm``, ``ha``, ``kl`` are kept as-is
      * bare ``p:`` and ``.`` are dropped

    Other words are split with ``segment_label``. Words that yield no phones
    (e.g. only pauses or ``#``) are dropped rather than leaving an empty
    ``| |`` segment.

    Args:
        label: Label string, e.g. from ``get_labels``.
        stress: If True, keep the leading ``'``, ````` or ``"`` marker on
            phones; otherwise strip it.

    Returns:
        The letter line (without a trailing newline).
    """
    stress_marks = ("'", "`", "\"")
    words = []
    # split() (not split(" ")) so leading/repeated spaces never yield "" words.
    for word in label.split():
        if word in ("p:pa", "pap:", "p:pap:", "pa"):
            words.append("pa")
        elif word in ("p:", "."):
            continue
        elif word == "|h":
            words.append("hes")
        elif word in ("sm", "ha", "kl"):
            words.append(word)
        else:
            phones = segment_label(word)
            if not stress:
                phones = [p[1:] if p[0] in stress_marks else p for p in phones]
            # A lone stress marker becomes "" once stripped; drop it so no
            # double spaces appear in the output.
            phones = [p for p in phones if p]
            if phones:
                words.append(" ".join(phones))
    return " | ".join(words) + " |"
# Compare the raw label string of one utterance with its letter-line form.
lbl = get_labels("/Users/joregan/Playing/waxholm/scenes_formatted/fp2043/fp2043.16.03.smp.mix")
plbl = proc_label(lbl)
print(lbl, plbl, sep="\n")
# Output:
# A:H'A: pa p: |h J'A:Ggv V'ILv pap: sm p:v S'E: pa H'U:R 2Dd'EM Bb']:TtE0NG Gg']:R 2Tt'I STt"A:VE0#STtR`\M p: ']: p: ']M J'A: Kk'AN F"O#2S`[TtA Tt'I F"IN#H`AM .
# A: H A: | pa | hes | J A: G g v | V I L v | pa | sm | v | S E: | pa | H U: R | 2D d E M | B b ]: T t E0 NG | G g ]: R | 2T t I | S T t A: V E0 S T t R \ M | ]: | ] M | J A: | K k A N | F O 2S [ T t A | T t I | F I N H A M |
# Build fairseq manifests and letter files.
#   *.tsv: first line is the audio root dir, then "<wav name>\t<num frames>".
#   *.ltr: one letter line per utterance, in the same order as its .tsv.
# Utterances named in TEST_FILES go to test. Of the rest, the first
# `valid_amount` (in sorted path order) go to valid and the remainder to train.
# NOTE(review): TRAIN_FILES is not consulted here, so every non-test file is
# used for valid/train — confirm this is intended.
_test_set = set(TEST_FILES)  # O(1) membership instead of a list scan per file
with open(OUTPUT_PATH / "train.tsv", "w") as train_tsv,\
     open(OUTPUT_PATH / "train.ltr", "w") as train_ltr,\
     open(OUTPUT_PATH / "valid.tsv", "w") as valid_tsv,\
     open(OUTPUT_PATH / "valid.ltr", "w") as valid_ltr,\
     open(OUTPUT_PATH / "test.tsv", "w") as test_tsv,\
     open(OUTPUT_PATH / "test.ltr", "w") as test_ltr:
    train_tsv.write(str(OUTPUT_PATH) + "\n")
    test_tsv.write(str(OUTPUT_PATH) + "\n")
    valid_tsv.write(str(OUTPUT_PATH) + "\n")
    valid_amount = 195
    # glob() order depends on the filesystem; sorting makes the valid/train
    # split reproducible across machines and runs.
    for smpfile in sorted(SCENES_PATH.glob("fp*/*.smp")):
        mixfile = f"{smpfile}.mix"
        if not Path(mixfile).exists():
            continue
        stem = smpfile.stem
        if f"{stem}.smp" in _test_set:
            out_tsv, out_ltr = test_tsv, test_ltr
        elif valid_amount > 0:
            out_tsv, out_ltr = valid_tsv, valid_ltr
            valid_amount -= 1
        else:
            out_tsv, out_ltr = train_tsv, train_ltr

        outwav = f"{stem}.wav"
        arr, _sr = smp_read_sf(str(smpfile))
        out_tsv.write(f"{outwav}\t{len(arr)}\n")
        # Manifest entries are relative to the root on the .tsv's first line
        # (OUTPUT_PATH), so the audio must be written there, not to the CWD.
        # str() keeps this working whether or not wave.open accepts Path objects.
        write_wav(str(OUTPUT_PATH / outwav), arr)
        label = get_labels(mixfile)
        out_ltr.write(proc_label(label) + "\n")