import json

# Load the raw lexicon. JSON is specified as UTF-8 (RFC 8259) and this data
# contains Swedish characters, so decode explicitly rather than relying on the
# platform's locale encoding.
with open("/tmp/waxholm_raw_lexicon.json", encoding="utf-8") as lexjson:
    data = json.load(lexjson)
def simplify_stops(text):
    """Normalise stop spellings in a label string to single upper-case letters.

    Two-letter spellings such as "Kk" collapse to "K", and any remaining
    lower-case k/g/d/t/b/p is upper-cased. The "kl", "pa" and "p:" tokens,
    which that upper-casing would otherwise clobber, are then restored.
    The rewrites are order-dependent and are applied strictly in sequence.
    """
    rewrites = (
        # Collapse upper+lower stop pairs to the upper-case letter.
        ("Kk", "K"),
        ("Gg", "G"),
        ("Dd", "D"),
        ("Tt", "T"),
        ("Bb", "B"),
        ("Pp", "P"),
        # Upper-case stray 'k', then give the "kl" token back its lower case.
        ("k", "K"),
        ("Kl", "kl"),
        # Upper-case the remaining lower-case stops.
        ("g", "G"),
        ("d", "D"),
        ("t", "T"),
        ("b", "B"),
        ("p", "P"),
        # Give the "pa" and "p:" tokens back their lower-case 'p'.
        ("Pa", "pa"),
        ("P:", "p:"),
    )
    for old, new in rewrites:
        text = text.replace(old, new)
    return text
def simplify_phoneme(text):
    """Remove every '+' from a phoneme string, then replace each "hy" with '#'.

    The '+' removal happens first, so "h+y" also becomes '#'.
    """
    without_plus = "".join(text.split("+"))
    return without_plus.replace("hy", "#")
def segment_label(label, skip_pause=True):
    """Split a label string into a list of phone tokens.

    Scanning left to right:
      * a two-character token ("NG", "E0", "SJ", "TJ", "kl", "sm", "pa",
        "ha", "öh", "Pa") is emitted whole;
      * "p:" is dropped, or emitted when ``skip_pause`` is False;
      * '#' and '~' are dropped;
      * anything else is emitted as one character, together with an optional
        preceding mark from ' ` " , 2 and an optional following mark from
        : 3 4.
    Spaces are not special-cased and are emitted as tokens of their own.
    """
    two_char_units = {"NG", "E0", "SJ", "TJ", "kl", "sm", "pa", "ha", "öh", "Pa"}
    prefix_marks = {"'", "`", "\"", ",", "2"}
    suffix_marks = {":", "3", "4"}
    segments = []
    pos = 0
    size = len(label)
    while pos < size:
        pair = label[pos:pos + 2]
        if pair in two_char_units:
            segments.append(pair)
            pos += 2
            continue
        if pair == "p:":
            if not skip_pause:
                segments.append(pair)
            pos += 2
            continue
        head = label[pos:pos + 1]
        if head in ("#", "~"):
            pos += 1
            continue
        begin = pos
        # The phone character sits after an optional prefix mark...
        core = pos + 1 if head in prefix_marks else pos
        # ...and may be followed by one suffix mark.
        end = core + 1 if label[core + 1:core + 2] in suffix_marks else core
        segments.append(label[begin:end + 1])
        pos = end + 1
    return segments
def lclem(lower):
    """Lower-case a transcript word unless it is an X-tag.

    A word that both starts and ends with 'X' (e.g. "XinandX", or "X" itself)
    is returned unchanged. Every other word, including the empty string, is
    returned lower-cased. The empty string previously raised IndexError, and
    it does occur when text containing stray spaces is split on " ".
    """
    if lower.startswith("X") and lower.endswith("X"):
        return lower
    return lower.lower()
# Example entry from the raw lexicon (notebook output of data[0]):
#   {'stem': 'fp2024.1.03',
#    'smp': 'fp2024/fp2024.1.03.smp',
#    'text': 'tack det är bra',
#    'phoneme': "T'AK D'E:T+ 'Ä3R+ BR'A:",
#    'labels': "Tt'AKk Dd'E: BbR'A:",
#    'labels_original': "Tt'AKk Dd'E: BbR'A:"}

# Maps each X-tag word that can appear in an entry's text to the token that
# is expected in (and, when missing, inserted into) the phoneme sequence.
# "XavbrordX" maps to an empty string.
X_TAGS = {
    "XtvekX": "öh",
    "XinandX": "pa",
    "XsmackX": "sm",
    "XutandX": "pa",
    "XharklingX": "ha",
    "XklickX": "kl",
    "XavbrordX": "",
    "XskrattX": "ha",
    "XsuckX": "pa",
}
def check_x_tag(word, phoneme):
    """Return True if ``phoneme`` is the token X_TAGS expects for ``word``.

    "XavbrordX" accepts any phoneme. A word that is not in X_TAGS returns
    False; the function previously fell through and returned None.
    """
    if word == "XavbrordX":
        return True
    return word in X_TAGS and phoneme == X_TAGS[word]

def modify_phonemes_inner(word, phonemes, idx):
    """Make sure the X_TAGS token for ``word`` is present at ``phonemes[idx]``.

    Args:
        word: a transcript word; only words in X_TAGS (other than
            "XavbrordX") are acted on.
        phonemes: list of phoneme groups; it is never mutated.
        idx: index of the phoneme to check. It must be valid for
            ``phonemes``. -1 means "check the last element; append if missing".

    Returns:
        ``phonemes`` itself when nothing needs inserting. Otherwise a new list
        with the tag token inserted before ``idx`` (or appended when
        ``idx == -1``). This function previously returned None when the token
        was already present, which crashed its callers.
    """
    if word not in X_TAGS or word == "XavbrordX":
        return phonemes
    if check_x_tag(word, phonemes[idx]):
        # The expected token is already in place.
        return phonemes
    tag = X_TAGS[word]
    if idx == -1:
        return phonemes + [tag]
    return phonemes[:idx] + [tag] + phonemes[idx:]
# Sanity checks for modify_phonemes_inner: a missing X-tag token is inserted
# before the phoneme at idx, or appended at the end when idx is -1.
wds = ["XinandX", "lördag"]
phn = ["L'Ö32DA"]
assert modify_phonemes_inner(wds[0], phn, 0) == ["pa", "L'Ö32DA"]
wds = ["lördag", "XinandX"]
phn = ["L'Ö32DA"]
assert modify_phonemes_inner(wds[1], phn, -1) == ["L'Ö32DA", "pa"]
def modify_phonemes(words, phonemes):
    """Insert missing X-tag tokens into ``phonemes`` so they line up with ``words``.

    Walks ``words`` in order. For the i-th word, the phoneme at index i is
    checked via modify_phonemes_inner, or the last phoneme once i runs past
    the end of the (possibly growing) list. That call may insert a token.

    Returns:
        The resulting phoneme list. If ``phonemes`` is None or empty, an error
        is printed and [] is returned.

    Raises:
        TypeError: if ``phonemes`` is non-empty but not a list.
    """
    if not phonemes:
        print("Error with phonemes", phonemes)
        return []
    if not isinstance(phonemes, list):
        raise TypeError(f"phonemes must be a list, got {type(phonemes).__name__}")
    current = phonemes
    for i, word in enumerate(words):
        # Past the end of the list, check/append at the tail instead.
        p_i = i if i < len(current) else -1
        updated = modify_phonemes_inner(word, current, p_i)
        # Defensive: a None result from modify_phonemes_inner keeps the current
        # list instead of crashing on len(None) in the next iteration.
        if updated is not None:
            current = updated
    return current
            
# First pass: when an entry's text has a different number of words than its
# phoneme string has groups, insert missing X-tag tokens into the phonemes and
# keep the previous phoneme string alongside.
for entry in data:
    raw_phoneme = entry.get("phoneme")
    if not raw_phoneme:
        continue
    phonemes = raw_phoneme.strip().split(" ")
    phonemes_orig = " ".join(phonemes)
    words = entry["text"].strip().split(" ")
    if len(words) != len(phonemes):
        mod = modify_phonemes(words, phonemes)
        if mod and " ".join(mod) != phonemes_orig:
            # NOTE(review): key is spelled "phoneme_orginal" (sic); kept as-is
            # because readers of the dumped JSON may depend on it.
            entry["phoneme_orginal"] = entry["phoneme"]
            entry["phoneme"] = " ".join(mod)

# Write the auto-edited data once, after the whole pass, rather than
# re-serialising the entire dataset for every entry.
with open("/tmp/waxholm_autoedit.json", "w", encoding="utf-8") as autoedit:
    json.dump(data, autoedit)
# Second pass: for every entry whose words, phoneme groups and label groups
# line up one-to-one, index the stems as
#   entries[word][phoneme][label] -> list of stems.
# Entries that don't line up are collected in `rest`.
entries = {}
rest = []
for item in data:
    if "phoneme" not in item:
        continue
    # Add the X-tag the labels imply when the text lacks it.
    if item["labels"].startswith("sm") and not item["text"].startswith("XsmackX"):
        item["text"] = f'XsmackX {item["text"]}'
    elif item["labels"].startswith("öh") and not item["text"].startswith("XtvekX"):
        item["text"] = f'XtvekX {item["text"]}'
    # elif item["labels"].startswith("pa") and not item["text"].startswith("XutandX"):
    #     item["text"] = f'XutandX {item["text"]}'

    # Strip before splitting, as the first pass does, so stray leading/trailing
    # spaces don't produce empty tokens that break the alignment.
    phonemes = simplify_phoneme(item["phoneme"]).strip().split(" ")
    labels = simplify_stops(item["labels"]).strip().split(" ")
    words = [lclem(x) for x in item["text"].strip().split(" ")]

    if len(phonemes) == len(labels) == len(words):
        for word, phoneme, label in zip(words, phonemes, labels):
            stems = (
                entries.setdefault(word, {})
                .setdefault(phoneme, {})
                .setdefault(label, set())
            )
            stems.add(item["stem"])
    else:
        rest.append(item)

# JSON has no set type: convert each stem set to a list, sorted so the output
# is deterministic across runs. Each set is converted exactly once.
for by_phoneme in entries.values():
    for by_label in by_phoneme.values():
        for label, stems in by_label.items():
            by_label[label] = sorted(stems)

with open("/tmp/simple-aligned-entries.json", "w", encoding="utf-8") as simplef:
    json.dump(entries, simplef)
# When this was run as a notebook, len(rest) evaluated to 429.

# Unaligned entries whose labels start with "sm" / "öh" (the X_TAGS tokens for
# XsmackX / XtvekX) but whose text lacks the corresponding X-tag.
# NOTE(review): the alignment loop that fills `rest` already prefixes
# XsmackX / XtvekX for exactly these label prefixes before appending, so both
# lists look like they will always be empty — confirm whether this check is
# still needed.
smacks = [
    item for item in rest
    if item["labels"].startswith("sm") and not item["text"].startswith("XsmackX")
]
uh = [
    item for item in rest
    if item["labels"].startswith("öh") and not item["text"].startswith("XtvekX")
]