This commit is contained in:
Niels 2024-04-07 19:38:45 +02:00
parent 2ae80ef87a
commit a11e1b8f3c
1 changed files with 4 additions and 1 deletions

View File

@ -18,6 +18,9 @@ from .modules.transformer import (
) )
from .codebooks_patterns import DelayedPatternProvider from .codebooks_patterns import DelayedPatternProvider
from huggingface_hub import PyTorchModelHubMixin
def top_k_top_p_filtering( def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
): ):
@ -82,7 +85,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
class VoiceCraft(nn.Module): class VoiceCraft(nn.Module, PyTorchModelHubMixin):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.args = copy.copy(args) self.args = copy.copy(args)