def approx_match(time_a, time_b, slippage=(0.01 * 6)):
    """Return True if ``time_a`` and ``time_b`` differ by at most ``slippage``.

    The bound is inclusive and is expressed in the same units as the two
    timestamps; the default tolerance is 0.06.
    """
    delta = time_a - time_b
    return -slippage <= delta <= slippage
from pathlib import Path

# Looks like the file stem of one input recording; it is not referenced by the
# code shown here — presumably kept for ad-hoc spot checks (TODO confirm).
SAMPLE = "2442204240010034621"

_PLAYING = Path("/Users/joregan/Playing")
# Directory of text-transcription JSON files; each is looked up under the same
# file name as its phone-transcription counterpart.
TEXT = _PLAYING / "kbw2v"
# Directory of phone-transcription JSON files; one alignment is produced per file.
PHONES = _PLAYING / "rd_phonetic"
class Chunk:
    """A piece of transcribed content together with its time span.

    Built from a dict that has a ``'text'`` entry and a ``'timestamp'``
    sequence whose first two items are the start and end times.
    """

    def __init__(self, chunk):
        self.text = chunk['text']
        stamp = chunk['timestamp']
        self.start = stamp[0]
        self.end = stamp[1]

    def __repr__(self) -> str:
        return "[{} ({}, {})]".format(self.text, self.start, self.end)

class SimpleMerge(Chunk):
    """A one-to-one pairing of a text chunk (``left``) with a phone chunk (``right``).

    The span and ``text`` are taken from the text chunk; the phone chunk
    contributes its text as ``phone``. ``diff_start`` / ``diff_end`` record
    how far the text boundaries sit from the phone boundaries (text minus phone).
    """

    def __init__(self, left: Chunk, right: Chunk):
        # Chunk.__init__ is not called: it expects a raw dict, not a Chunk.
        word, phone = left, right
        self.text, self.phone = word.text, phone.text
        self.start, self.end = word.start, word.end
        self.diff_start = word.start - phone.start
        self.diff_end = word.end - phone.end

    def exact_length(self):
        """Return True when both boundaries coincide exactly (zero offset)."""
        return (self.diff_start, self.diff_end) == (0, 0)

    def __repr__(self) -> str:
        return "[{} :: {} ({}, {})]".format(self.text, self.phone, self.start, self.end)

class ComplexMerge(Chunk):
    """A many-to-many pairing of text chunks with phone chunks.

    ``left`` is the text side and ``right`` the phone side. Each may be a list
    of chunks, a single chunk, or ``None`` (meaning "no chunks on this side").
    The merged span runs from the earlier first-chunk start to the later
    last-chunk end of the two sides. ``text`` and ``phone`` are the
    space-joined texts of the left and right chunks respectively.
    """

    def __init__(self, left, right):
        # Chunk.__init__ is not called: it expects a raw dict, not chunks.
        self.left_chunks = self._as_chunk_list(left)
        self.right_chunks = self._as_chunk_list(right)

        self.start = self.get_start()
        self.end = self.get_end()

        self.text = " ".join(x.text for x in self.left_chunks)
        self.phone = " ".join(x.text for x in self.right_chunks)

    @staticmethod
    def _as_chunk_list(value):
        """Normalise a list / single chunk / None argument into a list of chunks."""
        if isinstance(value, list):
            # Stored as given (not copied), matching the original behaviour.
            return value
        if value is None:
            # Must be an assignment of an empty list; a mere `== []` comparison
            # would leave the attribute unset and crash get_start/get_end.
            return []
        return [value]

    def get_start(self):
        """Return the earlier of the two sides' first-chunk starts.

        Returns None when both sides are empty; on a tie the phone side's
        value is returned.
        """
        if not self.left_chunks and not self.right_chunks:
            return None
        if not self.left_chunks:
            return self.right_chunks[0].start
        if not self.right_chunks:
            return self.left_chunks[0].start
        left_start = self.left_chunks[0].start
        right_start = self.right_chunks[0].start
        return left_start if left_start < right_start else right_start

    def get_end(self):
        """Return the later of the two sides' last-chunk ends.

        Returns None when both sides are empty; on a tie the phone side's
        value is returned.
        """
        if not self.left_chunks and not self.right_chunks:
            return None
        if not self.left_chunks:
            return self.right_chunks[-1].end
        if not self.right_chunks:
            return self.left_chunks[-1].end
        left_end = self.left_chunks[-1].end
        right_end = self.right_chunks[-1].end
        return left_end if left_end > right_end else right_end

    def __repr__(self) -> str:
        return f"[{self.text} :: {self.phone} ({self.start}, {self.end})]"

        
class PhoneChunk(Chunk):
    """A phone chunk emitted on its own, without a matching text chunk.

    Copies ``text``, ``start`` and ``end`` from an existing chunk object
    instead of parsing a raw dict the way ``Chunk.__init__`` does.
    """

    def __init__(self, chunk):
        for attr in ("text", "start", "end"):
            setattr(self, attr, getattr(chunk, attr))
class WordChunk(Chunk):
    """A text chunk emitted on its own, without a matching phone chunk.

    Copies ``text``, ``start`` and ``end`` from an existing chunk object
    instead of parsing a raw dict the way ``Chunk.__init__`` does.
    """

    def __init__(self, chunk):
        source = chunk
        self.text, self.start, self.end = source.text, source.start, source.end
def create_merges(text_chunks, phone_chunks):
    """Pair up text chunks and phone chunks whose timestamps line up.

    Walks both lists with one cursor each. This presumably requires both
    lists to be sorted by time, since the two-pointer walk relies on it.

    Produces, in order:

    * ``SimpleMerge`` when the current text and phone chunks both start and
      end within ``approx_match`` tolerance of each other;
    * ``ComplexMerge`` when the starts match but one side spans several
      chunks of the other side;
    * ``PhoneChunk`` / ``WordChunk`` for a chunk that ends before the other
      side's current chunk begins (i.e. unaligned). The same types are used
      for any chunks left over once the other list is exhausted.

    When the two current chunks overlap but their starts do not match, a
    diagnostic line is printed and both chunks are skipped.

    Returns:
        list of SimpleMerge / ComplexMerge / PhoneChunk / WordChunk objects.
    """
    merged = []
    pci = 0
    tci = 0

    while pci < len(phone_chunks) and tci < len(text_chunks):
        text = text_chunks[tci]
        phone = phone_chunks[pci]
        am_start = approx_match(text.start, phone.start)
        am_end = approx_match(text.end, phone.end)

        if am_start and am_end:
            merged.append(SimpleMerge(text, phone))

        elif am_start:
            cur_text = [text]
            cur_phone = [phone]
            if phone.end < text.end:
                # The text chunk spans several phone chunks: absorb phone
                # chunks until one ends at (about) the text end, or past it.
                while (not approx_match(text.end, phone_chunks[pci].end)
                       and phone_chunks[pci].end < text.end):
                    pci += 1
                    if pci >= len(phone_chunks):
                        break
                    cur_phone.append(phone_chunks[pci])
            else:
                # The phone chunk spans several text chunks: absorb text
                # chunks until one ends at (about) the phone end, or past it.
                while (not approx_match(text_chunks[tci].end, phone.end)
                       and text_chunks[tci].end < phone.end):
                    tci += 1
                    if tci >= len(text_chunks):
                        break
                    cur_text.append(text_chunks[tci])
            merged.append(ComplexMerge(cur_text, cur_phone))

        elif phone.end < text.start:
            # The phone chunk lies wholly before the current text chunk.
            # Emit *it* as unaligned and advance only the phone cursor, so the
            # same text chunk is re-tested against the next phone chunk.
            merged.append(PhoneChunk(phone))
            pci += 1
            continue

        elif text.end < phone.start:
            # Mirror case: the text chunk lies wholly before the phone chunk.
            merged.append(WordChunk(text))
            tci += 1
            continue

        else:
            # NOTE(review): overlapping chunks whose starts don't match are
            # only reported here; neither one reaches the output.
            print("else", text, phone)

        tci += 1
        pci += 1

    # Whatever remains on one side has nothing left to align with; keep it in
    # the output as unaligned rather than silently dropping it. (Slicing past
    # the end yields an empty list, so overshooting cursors are harmless.)
    merged.extend(WordChunk(chunk) for chunk in text_chunks[tci:])
    merged.extend(PhoneChunk(chunk) for chunk in phone_chunks[pci:])

    return merged
def print_merges(filename, source, merged):
    """Write alignment results to ``filename`` as tab-separated lines.

    Each line starts with ``source``, then:

    * for a ``PhoneChunk``: ``<UNALIGNED PHONE>``, the chunk text, start, end;
    * for a ``WordChunk``: ``<UNALIGNED WORD>``, the chunk text, start, end;
    * otherwise (merged pairs): text, phone, start, end.

    The file is overwritten and written as UTF-8, so non-ASCII text and phone
    symbols do not depend on the platform's locale encoding.
    """
    with open(filename, "w", encoding="utf-8") as outfile:
        for merge in merged:
            if isinstance(merge, PhoneChunk):
                outfile.write(f'{source}\t<UNALIGNED PHONE>\t{merge.text}\t{merge.start}\t{merge.end}\n')
            elif isinstance(merge, WordChunk):
                outfile.write(f'{source}\t<UNALIGNED WORD>\t{merge.text}\t{merge.start}\t{merge.end}\n')
            else:
                outfile.write(f'{source}\t{merge.text}\t{merge.phone}\t{merge.start}\t{merge.end}\n')
import json

# For every phone-transcription JSON in PHONES, load the same-named text JSON
# from TEXT, align their chunks, and write one TSV per file.
for file in PHONES.glob("*.json"):
    # Both inputs are read as UTF-8 explicitly: the default text encoding is
    # locale-dependent, and transcriptions/phone symbols may be non-ASCII.
    with open(TEXT / file.name, encoding="utf-8") as text:
        text_json = json.load(text)
    text_chunks = [Chunk(chunk) for chunk in text_json['chunks']]

    with open(file, encoding="utf-8") as phones:
        phone_json = json.load(phones)
    phone_chunks = [Chunk(chunk) for chunk in phone_json['chunks']]

    merged = create_merges(text_chunks, phone_chunks)

    print_merges(f"/Users/joregan/Playing/rd_tpalign/{file.stem}.tsv", file.stem, merged)