extraction,training,data,weights

This commit is contained in:
jason-on-salt-a40
2024-03-24 19:43:37 -07:00
parent d754e9109a
commit a129883910
7 changed files with 686 additions and 176 deletions

View File

@@ -54,8 +54,6 @@ class dataset(torch.utils.data.Dataset):
y = [[int(n)+self.args.n_special for n in l] for l in encos]
else:
y = [[int(n) for n in l] for l in encos]
if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
y = y[:1]
except Exception as e:
logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
logging.info(f"error message: {e}")
@@ -141,15 +139,15 @@ class dataset(torch.utils.data.Dataset):
if self.args.pad_x:
res["x"] = torch.stack(out["x"], dim=0)
else:
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
res["x_lens"] = torch.LongTensor(out["x_len"])
if self.args.dynamic_batching:
if out['y'][0].ndim==2:
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
else:
assert out['y'][0].ndim==1, out['y'][0].shape
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
else:
res['y'] = torch.stack(out['y'], dim=0)
res["y_lens"] = torch.LongTensor(out["y_len"])