mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2025-06-05 21:49:11 +02:00
extraction,training,data,weights
This commit is contained in:
@ -54,8 +54,6 @@ class dataset(torch.utils.data.Dataset):
|
||||
y = [[int(n)+self.args.n_special for n in l] for l in encos]
|
||||
else:
|
||||
y = [[int(n) for n in l] for l in encos]
|
||||
if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
|
||||
y = y[:1]
|
||||
except Exception as e:
|
||||
logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
|
||||
logging.info(f"error message: {e}")
|
||||
@ -141,15 +139,15 @@ class dataset(torch.utils.data.Dataset):
|
||||
if self.args.pad_x:
|
||||
res["x"] = torch.stack(out["x"], dim=0)
|
||||
else:
|
||||
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
|
||||
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
|
||||
res["x_lens"] = torch.LongTensor(out["x_len"])
|
||||
if self.args.dynamic_batching:
|
||||
if out['y'][0].ndim==2:
|
||||
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
|
||||
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
|
||||
res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
|
||||
else:
|
||||
assert out['y'][0].ndim==1, out['y'][0].shape
|
||||
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
|
||||
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
|
||||
else:
|
||||
res['y'] = torch.stack(out['y'], dim=0)
|
||||
res["y_lens"] = torch.LongTensor(out["y_len"])
|
||||
|
Reference in New Issue
Block a user