mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Save alignment_heads
This commit is contained in:
@@ -199,6 +199,10 @@ def torch_to_mlx(
|
||||
mlx_model = Whisper(torch_model.dims, dtype)
|
||||
params = tree_map(lambda p: p.astype(dtype), params)
|
||||
mlx_model.update(params)
|
||||
|
||||
if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
|
||||
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
|
||||
|
||||
return mlx_model
|
||||
|
||||
|
||||
|
@@ -226,7 +226,6 @@ class Whisper(nn.Module):
|
||||
self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
|
||||
|
||||
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
|
||||
# todo: do we need this ?
|
||||
if isinstance(dump, np.ndarray):
|
||||
self.alignment_heads = mx.array(dump)
|
||||
elif isinstance(dump, bytes):
|
||||
|
Reference in New Issue
Block a user