fix alignment

This commit is contained in:
Awni Hannun 2025-02-07 09:30:20 -08:00
parent 9414429309
commit f3960c2b4b
2 changed files with 3 additions and 3 deletions

View File

@@ -142,7 +142,7 @@ def find_alignment(
# heads * tokens * frames
weights = mx.stack(
-[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
+[cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
)
weights = weights[:, :, : num_frames // 2]
weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)

View File

@@ -228,13 +228,13 @@ class Whisper(nn.Module):
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
if isinstance(dump, np.ndarray):
-self.alignment_heads = dump
+self.alignment_heads = mx.array(dump)
elif isinstance(dump, bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
-self.alignment_heads = np.asarray(mask.nonzero()).T
+self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
else:
raise ValueError(
f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"