mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2025-06-05 21:49:11 +02:00
extraction,training,data,weights
This commit is contained in:
@ -504,7 +504,7 @@ class VoiceCraft(nn.Module):
|
||||
ntokens = []
|
||||
top10acc = []
|
||||
for k, (logit, target) in enumerate(zip(logits, targets)):
|
||||
loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None))
|
||||
loss.append(F.cross_entropy(logit, target, reduction='mean'))
|
||||
top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
|
||||
ntokens.append(len(logit))
|
||||
|
||||
@ -988,6 +988,8 @@ class VoiceCraft(nn.Module):
|
||||
for jj in range(1,self.args.n_codebooks):
|
||||
logits_adjust[jj][eog_inference] = -10000
|
||||
logits_adjust[jj][self.args.empty_token] = -10000
|
||||
if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
|
||||
logits_adjust[0][eog_inference] = -10000
|
||||
##################### silence repetition handling #####################
|
||||
if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
|
||||
if logits_adjust[0, prev_token] < 0:
|
||||
@ -1237,6 +1239,8 @@ class VoiceCraft(nn.Module):
|
||||
for jj in range(1,self.args.n_codebooks):
|
||||
logits_adjust[:,jj,eog_inference] = -10000
|
||||
logits_adjust[:,jj,self.args.empty_token] = -10000
|
||||
if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
|
||||
logits_adjust[:,:,eog_inference] = -10000
|
||||
##################### silence repetition handling #####################
|
||||
for b in range(batch_size):
|
||||
prev_token = prev_tokens[b]
|
||||
|
Reference in New Issue
Block a user