# Build a supplementary pronunciation dictionary: for every word that occurs in
# the HSI TextGrids and already has an entry in the MFA ARPA dictionary,
# generate additional pronunciation variants (consonant substitutions, optional
# deletions, vowel alternatives) and write the new ones to a separate file.
dictfile = "/Users/joregan/Documents/MFA/pretrained_models/dictionary/hsi_english_us_arpa.dict"
tgfiles = "/Users/joregan/Playing/hsi_ctmedit/textgrid/"
outdict = "/tmp/added.dict"
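
# optional consonant substitutions: wherever the phone on the left occurs, the
# phone on the right is also offered as an alternative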
EXP = {
    "JH": "Y",
    "W": "V",
    "V": "B",
    "B": "V"
}

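# consonants after which AE also gets an AA (TRAP/BATH) alternative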
TRAP_BATH = ["S", "F", "TH", "N", "V", "Z", "SH", "M", "L", "DH"]

# everything but 'ER'
PURE_VOWELS = [
    "AA",
    "AE",
    "AH",
    "AO",
    "AW",
    "AY",
    "EH",
    "EY",
    "IH",
    "IY",
    "OW",
    "OY",
    "UH",
    "UW"
]


def is_vowel(phone, pure=True):
    # with pure=False, the r-coloured vowel ER also counts as a vowel
    if not pure and phone.startswith("ER"):
        return True
    # strip the stress digit before checking against the vowel list
    if phone[-1:] in ["0", "1", "2"]:
        phone = phone[:-1]
    return phone in PURE_VOWELS


def expand_inner(text):
    """Return one list of alternative phones per output slot."""
    phones = text.split(" ")
    outphones = []
    # initial S + T/P/M clusters: offer an optional epenthetic EH0 before the S
    if len(phones) >= 2 and phones[0] == "S" and phones[1] in ["T", "P", "M"]:
        outphones.append(["EH0", ""])
    next_opt = False
    for i in range(len(phones)):
        cur = [phones[i]]
        # the previous phone marked this one as optionally deletable
        if next_opt:
            cur.append("")
            next_opt = False
        # one-for-one consonant substitutions (JH/Y, W/V, V/B)
        if phones[i] in EXP:
            cur.append(EXP[phones[i]])
        # T after S or N may be dropped
        if i < len(phones) - 1 and phones[i] in ["S", "N"] and phones[i+1] == "T":
            next_opt = True
        # TRAP/BATH: AE before these consonants also gets an AA variant
        if i < len(phones) - 1 and phones[i].startswith("AE") and phones[i+1] in TRAP_BATH:
            cur.append("AA" + phones[i][-1])
        # and the reverse before R: AA may be fronted to AE
        if i < len(phones) - 1 and phones[i].startswith("AA") and phones[i+1] == "R":
            cur.append("AE" + phones[i][-1])
        # ER may also be realised as EH or AH (keeping the stress mark)
        if phones[i].startswith("ER"):
            emph = phones[i][-1]
            cur.append("EH" + emph)
            cur.append("AH" + emph)
        # R after a vowel may be dropped
        if i < len(phones) - 1 and is_vowel(phones[i]) and phones[i+1] == "R":
            next_opt = True
        # vowel-initial words get an optional prothetic HH
        if i == 0 and is_vowel(phones[i], False):
            outphones = [["HH", ""]]
        outphones.append(cur)
    return outphones

import itertools

def expand(text):
    # every combination of the per-slot alternatives
    return [list(x) for x in itertools.product(*expand_inner(text))]
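
# A quick check of the vowel rules (an added illustrative example, not from the
# original run): word-initial AE followed by S triggers both the optional
# prothetic HH and the TRAP/BATH (AE -> AA) alternative.
[" ".join(x).strip() for x in expand("AE1 S K")]
# ['HH AE1 S K', 'HH AA1 S K', 'AE1 S K', 'AA1 S K']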

# load the existing MFA dictionary: word -> list of pronunciations
# (the pronunciation is the last tab-separated field on each line)
entries = {}
with open(dictfile) as df:
    for line in df:
        parts = line.strip().split("\t")
        if parts[0] not in entries:
            entries[parts[0]] = []
        entries[parts[0]].append(parts[-1])

from pathlib import Path
from praatio import textgrid
import re

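# normalise a transcript chunk: drop tokens starting with "[" (event markers),
# strip surrounding punctuation and upper-case the rest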
def norm(text):
    words = text.split(" ")
    words = [w.strip("\",.;:?!").upper() for w in words if not w.startswith("[")]
    return words


seen_words = set()
missing = []
new_entries = set()

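# walk every TextGrid, choose the transcript tier, and collect new
# pronunciation variants for every word the base dictionary already contains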
for textgridfile in Path(tgfiles).glob("*.[Tt]ext[Gg]rid"):
    tg = textgrid.openTextgrid(textgridfile, includeEmptyIntervals=False)

    # prefer a lone tier; otherwise fall back through the known tier names
    if len(tg.tierNames) == 1:
        tier = tg.getTier(tg.tierNames[0])
    elif "whisperx" in tg.tierNames:
        tier = tg.getTier("whisperx")
    elif "utterances" in tg.tierNames:
        tier = tg.getTier("utterances")
    elif "words" in tg.tierNames:
        tier = tg.getTier("words")
    else:
        print("Be careful: file", textgridfile, "has none of the expected tier names")
        continue

    for start, end, text in tier.entries:
        # skip intervals that consist solely of a bracketed marker
        if re.match(r"^\[[^]]+\]$", text):
            continue

        for word in norm(text):
            word = word.lower()
            if word in seen_words:
                continue
            seen_words.add(word)
            if word not in entries:
                missing.append(word)
                continue
            # skip truncated words
            if word.endswith("-"):
                continue
            for pron in entries[word]:
                for expanded in expand(pron):
                    # filter existing pronunciations!
                    joined = " ".join(expanded)
                    # optional (empty) phones leave double spaces behind
                    joined = re.sub("  +", " ", joined.strip())
                    if joined not in entries[word]:
                        new_entries.add(f"{word}\t{joined}")
[" ".join(x).strip() for x in expand("JH AH1 S T")]
['JH AH1 S T', 'JH AH1 S', 'Y AH1 S T', 'Y AH1 S']
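
# write only the newly generated variant pronunciations; the base dictionary
# file is left untouched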
with open(outdict, "w") as outf:
    for i in sorted(new_entries):
        outf.write(i + "\n")

set(missing)
# {''} -- only the empty token (left over from extra spaces or stripped
# punctuation in the transcripts) has no dictionary entry; every real word
# is covered