mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 00:30:09 +08:00
fix alignment
This commit is contained in:
parent
9414429309
commit
f3960c2b4b
@ -142,7 +142,7 @@ def find_alignment(
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = mx.stack(
|
||||
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
|
||||
[cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
|
||||
)
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
|
||||
|
@ -228,13 +228,13 @@ class Whisper(nn.Module):
|
||||
|
||||
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
|
||||
if isinstance(dump, np.ndarray):
|
||||
self.alignment_heads = dump
|
||||
self.alignment_heads = mx.array(dump)
|
||||
elif isinstance(dump, bytes):
|
||||
array = np.frombuffer(
|
||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||
).copy()
|
||||
mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
|
||||
self.alignment_heads = np.asarray(mask.nonzero()).T
|
||||
self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
|
||||
|
Loading…
Reference in New Issue
Block a user