# Process annotated JSON
# Add word-level start/end times to the GPT-generated references.
# Load a single merged-annotation file for interactive exploration.
EG = "/Users/joregan/Playing/merged_annotations/hsi_4_0717_209_002.json"
import json
with open(EG) as example_file:
    raw_json = example_file.read()
data = json.loads(raw_json)
# Example of one entry in a merged-annotation JSON file. Segments are keyed by
# id. Each has "general" (id, start, end), "high_level" (topic information),
# "low_level" (resolved references etc.) and the "snippet" text.
# The code below reads general.start/end, high_level.topic_name,
# low_level.resolved_references (and optionally
# low_level.resolved_references_indices) and snippet.
# NOTE(review): SAMPLE is not referenced by the visible code; it appears to be
# kept for reference only.
SAMPLE = """
{
"2": {
"general": {
"id": 2,
"start": 5.970287890159987,
"end": 9.551266621561616
},
"high_level": {
"topic_name": "reference_real",
"current_topic": "vase",
"topic_change": false,
"topic_duration_id": 0,
"spatial_reference": "fourth vase",
"referenced_object": [
"vase"
]
},
"low_level": {
"resolved_references": {
"the fourth vase": "vase",
"it": "vase"
},
"spatial_relationships": [],
"resolved_adverbs": []
},
"snippet": "But I'm wondering about the fourth vase. Where is it? Did you break it?"
}
}
"""
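# A segment's "low_level" may also contain "resolved_references_indices". It
# restricts a reference to the occurrences that start at the listed (0-based)
# word indices of the snippet (see prune_manual_index), e.g.:
#   "resolved_references_indices": {
#     "it": [2, 4]
#   }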
def load_tsv(filename):
    """Read a word-timing TSV file into a list of word dicts.

    Each non-blank line must have at least three tab-separated columns:
    start time, end time and the word. Extra columns are ignored.

    Args:
        filename: path of the TSV file (str or path-like).

    Returns:
        A list of ``{"start": float, "end": float, "word": str}`` dicts in
        file order.

    Raises:
        ValueError: if a start/end column is not a number.
        IndexError: if a non-blank line has fewer than three columns.
    """
    data = []
    # Explicit UTF-8: the default text encoding of open() is locale-dependent.
    with open(filename, encoding="utf-8") as inf:
        for line in inf:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing on float("").
                continue
            parts = line.split("\t")
            data.append({
                "start": float(parts[0]),
                "end": float(parts[1]),
                "word": parts[2],
            })
    return data
def slice_tsv_data(data, start, end):
    """Return the word entries lying entirely inside ``[start, end]``.

    Args:
        data: list of word dicts with "start"/"end" keys, as produced by
            ``load_tsv``. String "start"/"end" values are converted to float
            in place, so the caller's dicts are modified.
        start: window start, compared against each word's "start".
        end: window end, compared against each word's "end".

    Returns:
        The matching word dicts (the same objects, not copies), in input order.

    The scan stops at the first word ending after ``end``, so ``data`` is
    assumed to be sorted by time (TODO confirm for all inputs). A word that
    begins before ``start`` is never included, even if it ends inside the
    window.
    """
    ret = []
    for datum in data:
        if isinstance(datum["start"], str):
            datum["start"] = float(datum["start"])
        if isinstance(datum["end"], str):
            datum["end"] = float(datum["end"])
        if datum["start"] >= start and datum["end"] <= end:
            ret.append(datum)
        elif datum["end"] > end:
            # Past the window; with time-ordered input nothing later can fit.
            break
    return ret
def clean_text(text):
    """Normalise text for comparison.

    Splits on single spaces, lowercases every token, and strips leading and
    trailing ``.,;?!`` from each token. Then rejoins the tokens with single
    spaces, so the token count (including empty tokens) is preserved.
    """
    tokens = text.split(" ")
    normalised = (token.lower().strip(".,;?!") for token in tokens)
    return " ".join(normalised)
from pathlib import Path
# Machine-specific input/output locations.
TSVS = Path("/Users/joregan/Playing/hsi/word_annotations/")  # word-timing TSVs (<base>_main.tsv)
JSON = Path("/Users/joregan/Playing/merged_annotations/")  # merged annotation JSON files
OUTP = Path("/Users/joregan/Playing/timed_annotations/")  # annotated JSON output
# Create the output directory unless it already exists (parent must exist).
OUTP.mkdir(exist_ok=True)
def get_indices(needle, haystack, checkpos=True):
    """Find every occurrence of ``needle`` as a run of whole words in ``haystack``.

    Both strings are split on single spaces. Each word is lowercased and
    stripped of leading/trailing ``,?.;:!`` before comparison.

    Args:
        needle: the phrase to look for.
        haystack: the text to search.
        checkpos: if true, also accept the last word of ``needle`` in
            possessive form (``'s`` or ``’s``), e.g. "vase" matches "vase's".

    Returns:
        A list of ``(start, end)`` half-open word-index pairs into
        ``haystack.split(" ")``, in order of occurrence.
    """
    # Same characters as clean_text strips, plus ':' (stripped here originally).
    punct = ",?.;:!"
    nwords = [w.lower().strip(punct) for w in needle.split(" ")]
    hwords = [w.lower().strip(punct) for w in haystack.split(" ")]
    possessives = []
    if checkpos:
        # Accept both the ASCII and the typographic apostrophe.
        possessives = [nwords[:-1] + [nwords[-1] + suffix] for suffix in ("'s", "\u2019s")]
    nlen = len(nwords)
    ret = []
    # Only windows that fit entirely inside the haystack can match.
    for i in range(len(hwords) - nlen + 1):
        window = hwords[i:i + nlen]
        if window == nwords or window in possessives:
            ret.append((i, i + nlen))
    return ret
def get_tsv_for_segment(segment, tsv_data):
    """Return the TSV word entries that spell out ``segment["snippet"]``.

    The words inside the segment's ``general.start``/``general.end`` window
    are taken from ``tsv_data`` (via ``slice_tsv_data``). If their text equals
    the snippet, they are returned as-is. Otherwise both texts are normalised
    with ``clean_text`` and the snippet is located as a run of whole tokens
    inside the window. The entries for that run are returned.

    Args:
        segment: one annotation segment (needs "general" with "start"/"end",
            and "snippet").
        tsv_data: word dicts as produced by ``load_tsv``.

    Returns:
        The aligned word dicts, or ``[]`` after printing a diagnostic when
        the snippet cannot be located exactly once.
    """
    assert "general" in segment, "Missing key 'general'"
    assert "start" in segment["general"], "Missing key 'start'"
    assert "end" in segment["general"], "Missing key 'end'"
    start = segment["general"]["start"]
    end = segment["general"]["end"]
    tsv = slice_tsv_data(tsv_data, start, end)
    tsv_words = " ".join([x["word"] for x in tsv])
    if segment["snippet"] == tsv_words:
        return tsv
    cleaned_snippet = clean_text(segment["snippet"])
    cleaned_text = clean_text(tsv_words)
    # Match whole tokens. A plain substring test also accepts snippets that
    # begin or end mid-word (e.g. "he" inside "the"), which cannot be mapped
    # back onto TSV entries. clean_text keeps one token per TSV word, so token
    # positions are also TSV positions.
    snippet_tokens = cleaned_snippet.split(" ")
    text_tokens = cleaned_text.split(" ")
    nlen = len(snippet_tokens)
    hits = [
        i for i in range(len(text_tokens) - nlen + 1)
        if text_tokens[i:i + nlen] == snippet_tokens
    ]
    if not hits:
        print("🙀 mismatch:", "🖇️", segment["snippet"], "🎧", tsv_words, cleaned_text.find(cleaned_snippet))
        return []
    if len(hits) > 1:
        # The snippet occurs several times in the window; which one is meant
        # cannot be decided here.
        print("🙀 ambiguous:", "🖇️", segment["snippet"], "🎧", tsv_words, hits)
        return []
    first = hits[0]
    tsv = tsv[first:first + nlen]
    tsv_words = " ".join([x["word"] for x in tsv])
    cleaned_text = clean_text(tsv_words)
    # Sanity check: fails only if a TSV "word" itself contains a space.
    assert cleaned_snippet == cleaned_text, f"🖇️ {cleaned_snippet} 🎧 {cleaned_text}"
    return tsv
def is_skippable(segment, strict=True):
    """Decide whether a segment should get no reference timings.

    Treats the topic label "reference_unreal" as an alias of
    "reference_imaginary" and rewrites it in place on the segment.

    A segment is skippable when its topic is "conversation_generic", when
    ``strict`` is true and its topic is "reference_imaginary", or when its
    ``low_level.resolved_references`` is an empty dict.
    """
    high_level = segment["high_level"]
    if high_level["topic_name"] == "reference_unreal":
        high_level["topic_name"] = "reference_imaginary"
    if strict:
        skip_topics = ["conversation_generic", "reference_imaginary"]
    else:
        skip_topics = ["conversation_generic"]
    if high_level["topic_name"] in skip_topics:
        return True
    return segment["low_level"]["resolved_references"] == {}
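# Example: one reference can occur several times in a snippet, so it maps to
# a list of (start, end) word spans, e.g. "that" -> [(1, 2), (8, 9)].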
def skip_overlapped_index(a, b):
    """Return True when span ``a`` lies wholly inside span ``b``.

    Spans are ``(start, end)`` pairs. Identical spans count as contained.
    """
    return b[0] <= a[0] and a[1] <= b[1]
assert skip_overlapped_index((1, 2), (1, 5))
assert not skip_overlapped_index((1, 5), (1, 2))
def prune_manual_index(indices, manual):
    """Keep only the spans whose start index is listed in ``manual``.

    ``indices`` holds ``(start, end)`` word spans. ``manual`` holds the start
    indices chosen by hand. Input order is preserved.
    """
    return [span for span in indices if span[0] in manual]
assert prune_manual_index([(1, 3), (5, 7)], [1]) == [(1, 3)]
assert prune_manual_index([(1, 3), (5, 7)], [1, 5]) == [(1, 3), (5, 7)]
def prune_dict_for_overlap(segments):
    """Drop word spans that are nested inside another reference's span.

    ``segments`` maps each reference string to a list of ``(start, end)``
    word spans. A span is removed when it lies within a span belonging to a
    *different* reference and is not identical to it. For example, "vase"
    inside "the fourth vase" is dropped. Identical spans (e.g. "It" and "it")
    are kept for both references rather than dropping both.

    All comparisons use the spans as passed in, so the result does not depend
    on dict order. Each reference's surviving spans keep their original
    order. The dict is updated in place and also returned.
    """
    if len(segments) <= 1:
        return segments
    # Snapshot first: pruning one reference must not affect the checks made
    # for the others.
    original = {key: list(spans) for key, spans in segments.items()}
    for key, spans in original.items():
        others = [
            span
            for other_key, other_spans in original.items()
            if other_key != key
            for span in other_spans
        ]
        segments[key] = [
            a for a in spans
            # Same containment test as skip_overlapped_index, minus exact ties.
            if not any(a != b and b[0] <= a[0] and a[1] <= b[1] for b in others)
        ]
    return segments
test = {
    "1": [(1, 3), (5, 7)],
    "2": [(9, 11)],
    "3": [(1, 4)]
}
exp = {
    "1": [(5, 7)],
    "2": [(9, 11)],
    "3": [(1, 4)]
}
assert prune_dict_for_overlap(test) == exp
def process_segment(segment, tsv_data):
    """Attach word-level times to each resolved reference of ``segment``.

    Sets ``segment["low_level"]["reference_times"]`` to a list of dicts with
    "start"/"end" (from the first/last aligned TSV word of the reference),
    "text" (the reference string) and "reference" (the value it resolves to
    in ``resolved_references``).

    Occurrences are found with ``get_indices`` on the snippet. They are
    optionally restricted by ``resolved_references_indices`` and then pruned
    of spans nested inside another reference's span.

    Segments accepted by ``is_skippable`` get no "reference_times". The same
    applies when the snippet could not be aligned to the TSV words.
    """
    if is_skippable(segment):
        return
    tsv = get_tsv_for_segment(segment, tsv_data)
    if not tsv:
        # Alignment failed (get_tsv_for_segment printed the reason). Without
        # aligned words, tsv[...][0] below would raise IndexError.
        return
    references = segment["low_level"]["resolved_references"]
    manual_idx = segment["low_level"].get("resolved_references_indices", {})
    # Word spans per reference; references are visited in dict order and
    # spans in snippet order.
    indices = {}
    for ref in references:
        indices[ref] = get_indices(ref, segment["snippet"])
        if ref in manual_idx:
            indices[ref] = prune_manual_index(indices[ref], manual_idx[ref])
    indices = prune_dict_for_overlap(indices)
    reftimes = []
    for ref in references:
        for index in indices[ref]:
            seq = tsv[index[0]:index[1]]
            if len(seq) != index[1] - index[0]:
                # The span runs past the aligned words (snippet and TSV token
                # counts disagree). Its times would be missing or wrong.
                continue
            reftimes.append({
                "start": seq[0]["start"],
                "end": seq[-1]["end"],
                "text": ref,
                "reference": references[ref]
            })
    segment["low_level"]["reference_times"] = reftimes
# Add reference timings to every merged-annotation file. Each result is written
# to OUTP under the same file name.
for jsonfile in JSON.glob("*.json"):
    base = jsonfile.stem
    outfile = OUTP / f"{base}.json"
    with open(jsonfile) as jsf:
        data = json.load(jsf)
    # Word timings for this file are read from <base>_main.tsv.
    rawtsv = load_tsv(str(TSVS / f"{base}_main.tsv"))
    for segment in data.values():
        process_segment(segment, rawtsv)  # annotates the segment in place
    with open(outfile, "w") as f:
        json.dump(data, f, indent=2)