diff --git a/inference_tts_scale.py b/inference_tts_scale.py
index 2ebb78c..9915b22 100644
--- a/inference_tts_scale.py
+++ b/inference_tts_scale.py
@@ -98,7 +98,9 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
     gen_sample = audio_tokenizer.decode(
         [(gen_frames, None)]
     )
-
+    # Empty CUDA cache between runs
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
     # return
     return concat_sample, gen_sample
 
@@ -187,4 +189,4 @@ if __name__ == "__main__":
         seg_save_fn_concat = f"{args.output_dir}/concat_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
 
         torchaudio.save(seg_save_fn_gen, gen_audio, args.codec_audio_sr)
-        torchaudio.save(seg_save_fn_concat, concated_audio, args.codec_audio_sr)
\ No newline at end of file
+        torchaudio.save(seg_save_fn_concat, concated_audio, args.codec_audio_sr)
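
For reference, a minimal sketch of the pattern this patch introduces: clearing PyTorch's CUDA allocator cache between successive inference calls so memory cached by one run is released before the next. The `run_inference` function and its matmul workload are hypothetical stand-ins for the model's decode step, used only to illustrate where the guard goes.

```python
import torch

def run_inference(batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a single TTS inference call."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        # Dummy workload standing in for audio_tokenizer.decode(...)
        x = batch.to(device)
        return x @ x.T

if __name__ == "__main__":
    for i in range(3):
        out = run_inference(torch.randn(256, 256))
        print(f"run {i}: output shape {tuple(out.shape)}")
        # Same guard as in the diff: only touch the CUDA allocator when a GPU is present.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
```

Note that `empty_cache()` does not free tensors still referenced by Python; it only returns the allocator's unused cached blocks to the driver, which is why calling it between runs (after the previous run's outputs go out of scope) is the useful spot.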