Add class
This commit is contained in:
parent
a11e1b8f3c
commit
92b283c741
|
@ -85,7 +85,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VoiceCraft(nn.Module, PyTorchModelHubMixin):
|
class VoiceCraft(nn.Module):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = copy.copy(args)
|
self.args = copy.copy(args)
|
||||||
|
@ -1410,4 +1410,9 @@ class VoiceCraft(nn.Module, PyTorchModelHubMixin):
|
||||||
res = res - int(self.args.n_special)
|
res = res - int(self.args.n_special)
|
||||||
flatten_gen = flatten_gen - int(self.args.n_special)
|
flatten_gen = flatten_gen - int(self.args.n_special)
|
||||||
|
|
||||||
return res, flatten_gen[0].unsqueeze(0)
|
return res, flatten_gen[0].unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin):
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
super().__init__(config)
|
Loading…
Reference in New Issue