Merge 8814295e98
into ddfef83331
This commit is contained in:
commit
cdede0d73a
159
edit_utils.py
159
edit_utils.py
|
@ -1,49 +1,120 @@
|
|||
def get_span(orig, new, editType):
|
||||
orig_list = orig.split(" ")
|
||||
new_list = new.split(" ")
|
||||
|
||||
flag = False # this indicate whether the actual edit follow the specified editType
|
||||
if editType == "deletion":
|
||||
assert len(orig_list) > len(new_list), f"the edit type is deletion, but new is not shorter than original:\n new: {new}\n orig: {orig}"
|
||||
diff = len(orig_list) - len(new_list)
|
||||
for i, (o, n) in enumerate(zip(orig_list, new_list)):
|
||||
if o != n: # assume the index of the first different word is the starting index of the orig_span
|
||||
|
||||
orig_span = [i, i + diff - 1] # assume that the indices are starting and ending index of the deleted part
|
||||
new_span = [i-1, i] # but for the new span, the starting and ending index is the two words that surround the deleted part
|
||||
flag = True
|
||||
break
|
||||
import re
|
||||
|
||||
|
||||
elif editType == "insertion":
|
||||
assert len(orig_list) < len(new_list), f"the edit type is insertion, but the new is not longer than the original:\n new: {new}\n orig: {orig}"
|
||||
diff = len(new_list) - len(orig_list)
|
||||
for i, (o, n) in enumerate(zip(orig_list, new_list)):
|
||||
if o != n: # insertion is just the opposite of deletion
|
||||
new_span = [i, i + diff - 1] # NOTE if only inserted one word, s and e will be the same
|
||||
orig_span = [i-1, i]
|
||||
flag = True
|
||||
break
|
||||
def levenshtein_distance(word1, word2):
|
||||
len1, len2 = len(word1), len(word2)
|
||||
# Initialize a matrix to store the edit distances, operations, and positions
|
||||
dp = [[(0, "", []) for _ in range(len2 + 1)] for _ in range(len1 + 1)]
|
||||
|
||||
elif editType == "substitution":
|
||||
new_span = []
|
||||
orig_span = []
|
||||
for i, (o, n) in enumerate(zip(orig_list, new_list)):
|
||||
if o != n:
|
||||
new_span = [i]
|
||||
orig_span = [i]
|
||||
break
|
||||
assert len(new_span) == 1 and len(orig_span) == 1, f"new_span: {new_span}, orig_span: {orig_span}"
|
||||
for j, (o, n) in enumerate(zip(orig_list[::-1], new_list[::-1])):
|
||||
if o != n:
|
||||
new_span.append(len(new_list) - j -1)
|
||||
orig_span.append(len(orig_list) - j - 1)
|
||||
flag = True
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(f"editType unknown: {editType}")
|
||||
# Initialize the first row and column
|
||||
for i in range(len1 + 1):
|
||||
dp[i][0] = (i, "d" * i)
|
||||
for j in range(len2 + 1):
|
||||
dp[0][j] = (j, "i" * j)
|
||||
|
||||
if not flag:
|
||||
raise RuntimeError(f"wrong editing with the specified edit type:\n original: {orig}\n new: {new}\n, editType: {editType}")
|
||||
# Fill in the rest of the matrix
|
||||
for i in range(1, len1 + 1):
|
||||
for j in range(1, len2 + 1):
|
||||
cost = 0 if word1[i - 1] == word2[j - 1] else 1
|
||||
# Minimum of deletion, insertion, or substitution
|
||||
deletion = dp[i - 1][j][0] + 1
|
||||
insertion = dp[i][j - 1][0] + 1
|
||||
substitution = dp[i - 1][j - 1][0] + cost
|
||||
min_dist = min(deletion, insertion, substitution)
|
||||
|
||||
return orig_span, new_span
|
||||
# which operation led to the minimum distance
|
||||
if min_dist == deletion:
|
||||
operation = dp[i - 1][j][1] + "d"
|
||||
elif min_dist == insertion:
|
||||
operation = dp[i][j - 1][1] + "i"
|
||||
else:
|
||||
operation = dp[i - 1][j - 1][1] + ("s" if cost else "=")
|
||||
|
||||
dp[i][j] = (min_dist, operation)
|
||||
|
||||
# min edit distance, list of operations, positions of operations
|
||||
return dp[len1][len2][0], dp[len1][len2][1]
|
||||
|
||||
|
||||
def extract_words(sentence):
|
||||
words = re.findall(r"\b\w+\b", sentence)
|
||||
return words
|
||||
|
||||
# edge cases for spans of deletion, insertion, substitution
|
||||
def handle_delete(start, end, orig, new):
|
||||
orig.append([start, end - 1])
|
||||
new.append([start - 1, start])
|
||||
|
||||
def handle_insert(start, end, orig, new):
|
||||
temp_new = [start - 1, start]
|
||||
orig.append(temp_new)
|
||||
new.append(orig[-1])
|
||||
orig[-1], new[-1] = new[-1], temp_new
|
||||
|
||||
def handle_substitute(start, end, orig, new):
|
||||
orig.append([start, end - 1])
|
||||
new.append([start, end - 1])
|
||||
|
||||
# editing the last index of the sentence is another edge case
|
||||
def handle_last_operation(prev_op, start, end, orig, new):
|
||||
if prev_op == 'd':
|
||||
handle_delete(start, end, orig, new)
|
||||
elif prev_op == 'i':
|
||||
handle_insert(start, end, orig, new)
|
||||
elif prev_op == 's':
|
||||
handle_substitute(start, end, orig, new)
|
||||
|
||||
# adjust spans according to edge case expected output
|
||||
def adjust_last_span(operations, orig, new):
|
||||
if operations[-1] == 'd':
|
||||
new[-1] = [new[-1][0] - 1, new[-1][1] - 1]
|
||||
orig[-1] = [orig[-1][0] - 1, orig[-1][0] - 1]
|
||||
elif operations[-1] == 'i':
|
||||
new[-1] = [new[-1][0] - 1, new[-1][1] - 1]
|
||||
orig[-1] = [orig[-1][0] - 1, orig[-1][0]]
|
||||
|
||||
def get_spans(operations):
|
||||
orig = []
|
||||
new = []
|
||||
prev_op = None
|
||||
start = 0
|
||||
end = 0
|
||||
for i, op in enumerate(operations):
|
||||
# prevent span duplication of sequential edits of the same type
|
||||
if op != '=':
|
||||
if op != prev_op:
|
||||
if prev_op:
|
||||
handle_last_operation(prev_op, start, end, orig, new)
|
||||
prev_op = op
|
||||
start = i
|
||||
end = i + 1
|
||||
else:
|
||||
if prev_op:
|
||||
handle_last_operation(prev_op, start, end, orig, new)
|
||||
prev_op = None
|
||||
start = end
|
||||
# edge case of last operation
|
||||
if prev_op:
|
||||
handle_last_operation(prev_op, start, end, orig, new)
|
||||
adjust_last_span(operations, orig, new)
|
||||
return orig, new
|
||||
|
||||
def get_edits(operations):
|
||||
used_edits = []
|
||||
prev_op = ''
|
||||
for op in operations:
|
||||
if op == 'i' and prev_op != 'i':
|
||||
used_edits.append("insertion")
|
||||
elif op == 'd' and prev_op != 'd':
|
||||
used_edits.append("deletion")
|
||||
elif op == 's' and prev_op != 's':
|
||||
used_edits.append("substitution")
|
||||
prev_op = op
|
||||
return used_edits
|
||||
|
||||
def parse_edit(orig_transcript, trgt_transcript):
|
||||
word1 = extract_words(orig_transcript)
|
||||
word2 = extract_words(trgt_transcript)
|
||||
distance, operations = levenshtein_distance(word1, word2)
|
||||
orig_span, new_span = get_spans(operations)
|
||||
return operations, orig_span, new_span
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue