extraction,training,data,weights

This commit is contained in:
jason-on-salt-a40
2024-03-24 19:43:37 -07:00
parent d754e9109a
commit a129883910
7 changed files with 686 additions and 176 deletions

View File

@@ -504,7 +504,7 @@ class VoiceCraft(nn.Module):
ntokens = []
top10acc = []
for k, (logit, target) in enumerate(zip(logits, targets)):
loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None))
loss.append(F.cross_entropy(logit, target, reduction='mean'))
top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
ntokens.append(len(logit))
@@ -988,6 +988,8 @@ class VoiceCraft(nn.Module):
for jj in range(1,self.args.n_codebooks):
logits_adjust[jj][eog_inference] = -10000
logits_adjust[jj][self.args.empty_token] = -10000
if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
logits_adjust[0][eog_inference] = -10000
##################### silence repetition handling #####################
if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
if logits_adjust[0, prev_token] < 0:
@@ -1237,6 +1239,8 @@ class VoiceCraft(nn.Module):
for jj in range(1,self.args.n_codebooks):
logits_adjust[:,jj,eog_inference] = -10000
logits_adjust[:,jj,self.args.empty_token] = -10000
if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
logits_adjust[:,:,eog_inference] = -10000
##################### silence repetition handling #####################
for b in range(batch_size):
prev_token = prev_tokens[b]