diff --git a/models/voicecraft.py b/models/voicecraft.py
index e3852c4..ab3cf37 100644
--- a/models/voicecraft.py
+++ b/models/voicecraft.py
@@ -409,10 +409,10 @@ class VoiceCraft(nn.Module):
                 .expand(-1, self.args.nhead, -1, -1)
                 .reshape(bsz * self.args.nhead, 1, src_len)
             )
-            # Check shapes and resize or broadcast as necessary
+            # Check shapes and resize+broadcast as necessary
             if xy_attn_mask.shape != _xy_padding_mask.shape:
-                # Assuming _xy_padding_mask has the correct shape and xy_attn_mask needs adjustment
-                xy_attn_mask = xy_attn_mask.expand_as(_xy_padding_mask) # Example approach
+                assert xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim, f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}"
+                xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(_xy_padding_mask.shape[0], 1, 1) # broadcast the 2D attn mask across batch*heads
             xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
 
             new_attn_mask = torch.zeros_like(xy_attn_mask)
@@ -459,7 +459,7 @@ class VoiceCraft(nn.Module):
         """
         x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
         x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
-        y = y[:, :y_lens.max()]
+        y = y[:, :, :y_lens.max()]
         assert x.ndim == 2, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
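
Note on the first hunk: xy_attn_mask is 2-D, (tgt_len, src_len), while _xy_padding_mask is 3-D, (bsz * nhead, 1, src_len), so the old expand_as call raises a RuntimeError: a size-tgt_len dim cannot be expanded into the padding mask's size-1 query dim. The replacement instead adds a leading batch*head axis and repeats the 2-D mask, after which logical_or broadcasts over the query dim. A minimal standalone sketch with toy sizes (the tensors below are synthetic stand-ins, not the model's real masks):

    import torch

    bsz, nhead, src_len = 2, 4, 5  # toy sizes, assumed for illustration

    # Causal mask as it exists before the fix: (src_len, src_len), bool.
    xy_attn_mask = torch.triu(
        torch.ones(src_len, src_len, dtype=torch.bool), diagonal=1
    )

    # Per-head key-padding mask, built exactly as in the hunk:
    # (bsz * nhead, 1, src_len).
    xy_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
    xy_padding_mask[:, -1] = True  # pretend the last position is padding
    _xy_padding_mask = (
        xy_padding_mask.view(bsz, 1, 1, src_len)
        .expand(-1, nhead, -1, -1)
        .reshape(bsz * nhead, 1, src_len)
    )

    # Old code path: RuntimeError, (5, 5) cannot expand to (8, 1, 5).
    # xy_attn_mask.expand_as(_xy_padding_mask)

    # New code path: add the batch*head axis, repeat, then merge the masks.
    assert xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim
    xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(_xy_padding_mask.shape[0], 1, 1)
    merged = xy_attn_mask.logical_or(_xy_padding_mask)
    print(merged.shape)  # torch.Size([8, 5, 5])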
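
Note on the second hunk: y is laid out as (bsz, n_codebooks, T), as the assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks confirms, so the old slice y[:, :y_lens.max()] cut along the codebook axis instead of time. That is a silent no-op whenever y_lens.max() >= n_codebooks (the over-padded time axis from gradient accumulation is never trimmed), and it drops codebooks and trips the assert whenever y_lens.max() is smaller. A toy check of both slices (sizes assumed for illustration):

    import torch

    bsz, n_codebooks, T = 2, 4, 10  # toy sizes
    y = torch.zeros(bsz, n_codebooks, T, dtype=torch.long)
    y_lens = torch.tensor([7, 9])   # valid length of each example

    old = y[:, :y_lens.max()]     # slices dim 1 (codebooks); shape stays (2, 4, 10)
    new = y[:, :, :y_lens.max()]  # trims the time axis: (2, 4, 9)
    print(old.shape, new.shape)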