mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-05 08:14:34 +08:00
Save alignment_heads
This commit is contained in:
@@ -199,6 +199,10 @@ def torch_to_mlx(
     mlx_model = Whisper(torch_model.dims, dtype)
     params = tree_map(lambda p: p.astype(dtype), params)
     mlx_model.update(params)
+
+    if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
+        mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
+
     return mlx_model
 
 
@@ -226,7 +226,6 @@ class Whisper(nn.Module):
         self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
 
     def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
-        # todo: do we need this ?
         if isinstance(dump, np.ndarray):
             self.alignment_heads = mx.array(dump)
         elif isinstance(dump, bytes):
Reference in New Issue
Block a user