This commit is contained in:
Niels 2024-04-07 19:38:45 +02:00
parent 2ae80ef87a
commit a11e1b8f3c
1 changed files with 4 additions and 1 deletions

View File

@ -18,6 +18,9 @@ from .modules.transformer import (
)
from .codebooks_patterns import DelayedPatternProvider
from huggingface_hub import PyTorchModelHubMixin
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
@ -82,7 +85,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
class VoiceCraft(nn.Module):
class VoiceCraft(nn.Module, PyTorchModelHubMixin):
def __init__(self, args):
super().__init__()
self.args = copy.copy(args)