fix overlapping margins

This commit is contained in:
Pranay Gosar 2024-04-17 19:01:39 -05:00
parent 73fac7c460
commit 8814295e98
1 changed files with 23 additions and 7 deletions

View File

@ -38,6 +38,7 @@
"# hyperparameters for inference\n",
"left_margin = 0.08\n",
"right_margin = 0.08\n",
"sub_amount = 0.01\n",
"codec_audio_sr = 16000\n",
"codec_sr = 50\n",
"top_k = 0\n",
@ -128,13 +129,15 @@
"source": [
"# propose what do you want the target modified transcript to be\n",
"orig_transcript = \"But when I had approached so near to them which the sense deceives, Lost not by distance any of its marks,\"\n",
"target_transcript = \"But when I had approached so near which the sense deceives, Lost not by distance any of its marks,\" # deletes \"to them\"\n",
"target_transcript = \"But I did approached so near to them which the sense deceives, Lost not by distance any of its marks,\"\n",
"\n",
"# from edit_utils import parse_edit, get_edits\n",
"\n",
"# run the script to turn user input to the format that the model can take\n",
"operations, orig_span, new_span = parse_edit(orig_transcript, target_transcript)\n",
"\n",
"used_edits = get_edits(operations)\n",
"print(used_edits)\n",
"print(used_edits) \n",
"\n",
"def process_span(span):\n",
" if span[0] > span[1]:\n",
@ -158,15 +161,28 @@
" starting_intervals.append(start)\n",
" ending_intervals.append(end)\n",
"\n",
"print(\"intervals: \", starting_intervals, ending_intervals)\n",
"\n",
"info = torchaudio.info(audio_fn)\n",
"audio_dur = info.num_frames / info.sample_rate\n",
"morphed_span = [(max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur))\n",
" for start, end in zip(starting_intervals, ending_intervals)] # in seconds\n",
"\n",
"def resolve_overlap(starting_intervals, ending_intervals, audio_dur, codec_sr, left_margin, right_margin, sub_amount):\n",
" while True:\n",
" morphed_span = [(max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur))\n",
" for start, end in zip(starting_intervals, ending_intervals)] # in seconds\n",
" mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]\n",
" # Check for overlap\n",
" overlapping = any(a[1] >= b[0] for a, b in zip(mask_interval, mask_interval[1:]))\n",
" if not overlapping:\n",
" break\n",
" \n",
" # Reduce margins\n",
" left_margin -= sub_amount\n",
" right_margin -= sub_amount\n",
" \n",
" return mask_interval\n",
"\n",
"\n",
"# span in codec frames\n",
"mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]\n",
"mask_interval = resolve_overlap(starting_intervals, ending_intervals, audio_dur, codec_sr, left_margin, right_margin, sub_amount)\n",
"mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now\n",
"\n",
"# load model, tokenizer, and other necessary files\n",