import json
def load_tsv(filename):
    # read a tab-separated file of start<TAB>end<TAB>word lines
    data = []
    with open(filename) as inf:
        for line in inf:
            parts = line.strip().split("\t")
            data.append({
                "start": float(parts[0]),
                "end": float(parts[1]),
                "word": parts[2]
            })
    return data
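# Hedged usage sketch for load_tsv: write a tiny tab-separated file and read it
# back; the temporary file and its contents are invented for illustration.
import os
import tempfile
_tmp = tempfile.NamedTemporaryFile("w", suffix=".tsv", delete=False)
_tmp.write("0.10\t0.35\thello\n0.40\t0.80\tworld\n")
_tmp.close()
assert load_tsv(_tmp.name) == [
    {"start": 0.10, "end": 0.35, "word": "hello"},
    {"start": 0.40, "end": 0.80, "word": "world"},
]
os.unlink(_tmp.name)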
def slice_tsv_data(data, start, end):
    # return the word entries that fall entirely within [start, end]
    ret = []
    for datum in data:
        # defensively coerce the times in case they were left as strings
        if isinstance(datum["start"], str):
            datum["start"] = float(datum["start"])
        if isinstance(datum["end"], str):
            datum["end"] = float(datum["end"])
        if datum["start"] >= start and datum["end"] <= end:
            ret.append(datum)
        elif datum["end"] > end:
            # the words are time-ordered, so nothing later can match
            return ret
    return ret
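# Minimal sketch of slice_tsv_data: only words whose start and end both lie
# inside the requested window are kept; the word list is invented.
_demo = [
    {"start": 0.0, "end": 0.5, "word": "a"},
    {"start": 0.6, "end": 1.0, "word": "b"},
    {"start": 1.1, "end": 1.5, "word": "c"},
]
assert slice_tsv_data(_demo, 0.6, 1.5) == _demo[1:]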
import re
def norm_spaces(text):
    return re.sub("  +", " ", text.strip())
def clean_text(text):
    text = norm_spaces(text)
    return " ".join([x.lower().strip(".,;?!") for x in text.split(" ")])
from pathlib import Path

TSVS = Path("/Users/joregan/Playing/hsi/word_annotations/")
JSON = Path("/Users/joregan/Playing/merged_annotations/")
OUTP = Path("/Users/joregan/Playing/timed_annotations/")
if not OUTP.is_dir():
    OUTP.mkdir()
def get_indices(needle, haystack, checkpos=True):
    # return (start, end) word-index pairs at which needle occurs in haystack
    ret = []
    nwords = [x.lower().strip(",?.;:()") for x in needle.split(" ")]
    hwords = [x.lower().strip(",?.;:") for x in haystack.split(" ")]
    # also allow the final word of the needle to appear as a possessive ("'s")
    nwordspos = nwords[:-1] + [f"{nwords[-1]}'s"]
    nlen = len(nwords)

    for i in range(len(hwords)):
        if hwords[i:i+nlen] == nwords:
            ret.append((i, i+nlen))
        elif checkpos and hwords[i:i+nlen] == nwordspos:
            ret.append((i, i+nlen))
    return ret
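# Sketch of get_indices on invented phrases: it returns word-index spans, and
# with checkpos it also matches a trailing possessive form of the needle.
assert get_indices("the lion", "I like the lion a lot") == [(2, 4)]
assert get_indices("the lion", "the lion's mane") == [(0, 2)]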
def clean_text2(text):
    nums = {
        "60": "sixty",
        "1": "one",
        "20th": "twentieth",
        "9th": "ninth",
        "5": "five"
    }
    text = norm_spaces(text)
    words = [x.lower().strip(".,;?!") for x in text.split(" ")]
    ret = []
    for word in words:
        if word.startswith("[") and word.endswith("]"):
            continue
        elif word.startswith("{") and word.endswith("}"):
            continue
        word = nums.get(word, word)
        word = word.replace(".", " ").replace(",", " ")
        ret.append(word)
    return " ".join(ret)
MANUAL = """
hsi_5_0718_210_001	17
hsi_5_0718_210_001	18
hsi_5_0718_210_001	114
hsi_4_0717_211_003	36
hsi_4_0717_211_003	42
hsi_3_0715_210_010	89
hsi_3_0715_209_008	31
hsi_3_0715_210_011	48
hsi_4_0717_211_002	6
hsi_5_0718_210_001	49
hsi_5_0718_209_003	7
hsi_6_0718_227_002	63
hsi_5_0718_209_001	1
hsi_6_0718_210_002	102
hsi_6_0718_210_002	33
hsi_6_0718_210_002	18
hsi_6_0718_209_001	95
hsi_3_0715_209_006	18
hsi_3_0715_227_001	21
hsi_4_0717_210_001	47
hsi_3_0715_210_010	87
hsi_3_0715_210_010	15
hsi_3_0715_209_006	30
hsi_3_0715_209_006	43
hsi_6_0718_211_002	14
"""

manual_segments = {}
for line in MANUAL.split("\n"):
    if line == "":
        continue
    parts = line.split("\t")
    if parts[0] not in manual_segments:
        manual_segments[parts[0]] = []
    manual_segments[parts[0]].append(parts[1])
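# manual_segments maps a recording base name to the segment IDs (kept as
# strings) that bypass the snippet/TSV consistency check below.
assert manual_segments["hsi_5_0718_210_001"][:2] == ["17", "18"]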
def get_tsv_for_segment(segment, tsv_data, filename=None, segment_id=None):
    assert "general" in segment, "Missing key 'general'"
    assert "start" in segment["general"], "Missing key 'start'"
    assert "end" in segment["general"], "Missing key 'end'"

    start = segment["general"]["start"]
    end = segment["general"]["end"]

    tsv = slice_tsv_data(tsv_data, start, end)
    tsv_words = " ".join([x["word"] for x in tsv])

    if filename and filename in manual_segments and segment_id and segment_id in manual_segments[filename]:
        return tsv

    if segment["snippet"] != tsv_words:
        cleaned_snippet = clean_text2(segment["snippet"])
        cleaned_text = clean_text2(tsv_words)

        if cleaned_snippet not in cleaned_text:
            if filename is not None and segment_id is not None:
                print(f"{filename}\t{segment_id}\t{segment['snippet']}\t{tsv_words}")
            else:
                print("🙀 mismatch:", "🖇️", segment["snippet"], "🎧", tsv_words, cleaned_text.find(cleaned_snippet))
            return []
        else:
            idxes = get_indices(cleaned_snippet, cleaned_text)
            assert len(idxes) == 1
            tsv = tsv[idxes[0][0]:idxes[0][1]]
            tsv_words = " ".join([x["word"] for x in tsv])
            cleaned_text = clean_text(tsv_words)
            assert cleaned_snippet == cleaned_text, f"🖇️ {cleaned_snippet} 🎧 {cleaned_text}"
    return tsv
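# Hedged sketch of get_tsv_for_segment: a minimal segment whose snippet matches
# the sliced TSV exactly, so the slice is returned unchanged. All values are
# invented for illustration.
_seg = {"general": {"start": 0.0, "end": 1.0}, "snippet": "hello world"}
_words = [
    {"start": 0.1, "end": 0.4, "word": "hello"},
    {"start": 0.5, "end": 0.9, "word": "world"},
]
assert get_tsv_for_segment(_seg, _words) == _words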
def is_skippable(segment, strict=True):
    skippables = ["conversation_generic"]
    if strict:
        skippables += ["reference_imaginary"]
    # some annotations use "current_topic" instead of "topic_name"
    if "topic_name" not in segment["high_level"]:
        if "current_topic" in segment["high_level"]:
            segment["high_level"]["topic_name"] = segment["high_level"]["current_topic"]
            del segment["high_level"]["current_topic"]
    # map "reference_unreal" to the name used elsewhere, "reference_imaginary"
    if segment["high_level"]["topic_name"] == "reference_unreal":
        segment["high_level"]["topic_name"] = "reference_imaginary"
    if segment["high_level"]["topic_name"] in skippables:
        return True
    elif segment["low_level"]["resolved_references"] == {}:
        return True
    else:
        return False
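# Sketch: a segment is skippable when its topic is in the skippable list or
# when it has no resolved references at all. The fields below are invented.
_skip = {
    "high_level": {"topic_name": "conversation_generic"},
    "low_level": {"resolved_references": {"the desk": "desk_1"}},
}
assert is_skippable(_skip) == True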
# a short reference like "that" can match at several index pairs, e.g. [(1, 2), (8, 9)];
# drop a pair that is fully contained within a pair from another reference
def skip_overlapped_index(a, b):
    # True if span a lies entirely inside span b
    if a[0] >= b[0] and a[1] <= b[1]:
        return True
    return False

assert skip_overlapped_index((1, 2), (1, 5)) == True
assert skip_overlapped_index((1, 5), (1, 2)) == False
def prune_manual_index(indices, manual):
    ret = []
    for index in indices:
        if index[0] in manual:
            ret.append(index)
    return ret

assert prune_manual_index([(1, 3), (5, 7)], [1]) == [(1, 3)]
assert prune_manual_index([(1, 3), (5, 7)], [1, 5]) == [(1, 3), (5, 7)]
def prune_dict_for_overlap(segments):
    # for each reference, prune index pairs that are fully contained inside an
    # index pair belonging to a different reference
    if len(segments.keys()) == 1:
        return segments
    for segment in segments:
        pruned = set()
        for seg2 in segments:
            if segment != seg2:
                for a in segments[segment]:
                    for b in segments[seg2]:
                        if skip_overlapped_index(a, b):
                            if a in pruned:
                                pruned.remove(a)
                        else:
                            pruned.add(a)
        segments[segment] = list(pruned)
    return segments

test = {
    "1": [(1, 3), (5, 7)],
    "2": [(9, 11)],
    "3": [(1, 4)]
}
exp = {
    "1": [(5, 7)],
    "2": [(9, 11)],
    "3": [(1, 4)]
}
assert prune_dict_for_overlap(test) == exp
def process_segment(segment, tsv_data, filename=None, segment_id=None):
    if is_skippable(segment):
        return
    tsv = get_tsv_for_segment(segment, tsv_data, filename, segment_id)
    references = segment["low_level"]["resolved_references"]
    manual_idx = segment["low_level"].get("resolved_references_indices", {})

    # these are ordered. Kinda.
    indices = {}
    for ref in references:
        indices[ref] = get_indices(ref, segment["snippet"])
        if ref in manual_idx:
            indices[ref] = prune_manual_index(indices[ref], manual_idx[ref])
    indices = prune_dict_for_overlap(indices)
    reftimes = []
    for ref in references:
        for index in indices[ref]:
            seq = tsv[index[0]:index[1]]
            if seq == []:
                continue
            start = seq[0]["start"]
            end = seq[-1]["end"]
            reftimes.append({
                "start": start,
                "end": end,
                "text": ref,
                "reference": references[ref]
            })
    segment["low_level"]["reference_times"] = reftimes
for jsonfile in JSON.glob("*.json"):
    base = jsonfile.stem
    with open(jsonfile) as jsf:
        data = json.load(jsf)
    rawtsv = load_tsv(str(TSVS / f"{base}_main.tsv"))
    outfile = OUTP / f"{base}.json"
    for seg in data:
        process_segment(data[seg], rawtsv, base, seg)
    with open(str(outfile), 'w') as f:
        json.dump(data, f, indent=2)
hsi_4_0717_227_004	8	these are doors so you can slide them out	you can take them and just walk
hsi_4_0717_227_004	27	well both the the man the wooden man and the lion there are from Africa like ancient African art	And the lion there is from ancient African art.
hsi_4_0717_227_004	28	I got them there when I was there and it was an	I got them there when I was there and it was
hsi_4_0717_227_004	39	that's why [spn]	
hsi_4_0717_227_004	51	feel bright or whatever	
print(jsonfile)
/Users/joregan/Playing/merged_annotations/hsi_6_0718_227_001.json
import json
import csv

def update_json_snippets_from_csv(json_path, csv_path, output_path):
    # Load JSON
    with open(json_path, 'r') as f:
        json_data = json.load(f)

    # Load CSV and store snippets by (start, end)
    csv_snippets = {}
    with open(csv_path, 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            if len(row) != 3:
                continue
            try:
                start = round(float(row[0]), 3)
                end = round(float(row[1]), 3)
                snippet = row[2]
                csv_snippets[(start, end)] = snippet
            except ValueError:
                continue

    # Replace JSON snippets based on matching start/end
    for entry in json_data.values():
        start = round(entry['general']['start'], 3)
        end = round(entry['general']['end'], 3)
        if (start, end) in csv_snippets:
            entry['snippet'] = csv_snippets[(start, end)]

    # Save updated JSON
    with open(output_path, 'w') as f:
        json.dump(json_data, f, indent=2)

    print(f"Updated JSON saved to: {output_path}")