Fix attention masking to be compatible with torch 2.1.0; fix crop ops

This commit is contained in:
jason-on-salt-a40 2024-04-05 13:33:03 -07:00
parent e2e598d900
commit 0face202d8
1 changed file with 4 additions and 4 deletions

View File

@@ -409,10 +409,10 @@ class VoiceCraft(nn.Module):
.expand(-1, self.args.nhead, -1, -1)
.reshape(bsz * self.args.nhead, 1, src_len)
)
# Check shapes and resize or broadcast as necessary
# Check shapes and resize+broadcast as necessary
if xy_attn_mask.shape != _xy_padding_mask.shape:
# Assuming _xy_padding_mask has the correct shape and xy_attn_mask needs adjustment
xy_attn_mask = xy_attn_mask.expand_as(_xy_padding_mask) # Example approach
assert xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim, f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}"
xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(_xy_padding_mask.shape[0], 1, 1) # Example approach
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask)
@@ -459,7 +459,7 @@ class VoiceCraft(nn.Module):
"""
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
y = y[:, :y_lens.max()]
y = y[:, :, :y_lens.max()]
assert x.ndim == 2, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape