mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2025-06-05 21:49:11 +02:00
Add HF
This commit is contained in:
@@ -18,6 +18,9 @@ from .modules.transformer import (
|
|||||||
)
|
)
|
||||||
from .codebooks_patterns import DelayedPatternProvider
|
from .codebooks_patterns import DelayedPatternProvider
|
||||||
|
|
||||||
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_filtering(
|
def top_k_top_p_filtering(
|
||||||
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
||||||
):
|
):
|
||||||
@@ -82,7 +85,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VoiceCraft(nn.Module):
|
class VoiceCraft(nn.Module, PyTorchModelHubMixin):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = copy.copy(args)
|
self.args = copy.copy(args)
|
||||||
|
Reference in New Issue
Block a user