fixed masking to be compatible with torch 2.1.0; fix crop ops
This commit is contained in:
parent
e2e598d900
commit
0face202d8
|
@ -409,10 +409,10 @@ class VoiceCraft(nn.Module):
|
|||
.expand(-1, self.args.nhead, -1, -1)
|
||||
.reshape(bsz * self.args.nhead, 1, src_len)
|
||||
)
|
||||
# Check shapes and resize or broadcast as necessary
|
||||
# Check shapes and resize+broadcast as necessary
|
||||
if xy_attn_mask.shape != _xy_padding_mask.shape:
|
||||
# Assuming _xy_padding_mask has the correct shape and xy_attn_mask needs adjustment
|
||||
xy_attn_mask = xy_attn_mask.expand_as(_xy_padding_mask) # Example approach
|
||||
assert xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim, f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}"
|
||||
xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(_xy_padding_mask.shape[0], 1, 1) # Example approach
|
||||
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
||||
|
||||
new_attn_mask = torch.zeros_like(xy_attn_mask)
|
||||
|
@ -459,7 +459,7 @@ class VoiceCraft(nn.Module):
|
|||
"""
|
||||
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
|
||||
x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
|
||||
y = y[:, :y_lens.max()]
|
||||
y = y[:, :, :y_lens.max()]
|
||||
assert x.ndim == 2, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
|
||||
|
|
Loading…
Reference in New Issue