Save alignment_heads

This commit is contained in:
bofenghuang
2024-01-06 17:47:46 +01:00
parent 441494b11a
commit 57111100a2
2 changed files with 4 additions and 1 deletions

View File

@@ -199,6 +199,10 @@ def torch_to_mlx(
mlx_model = Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
return mlx_model

View File

@@ -226,7 +226,6 @@ class Whisper(nn.Module):
self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
# todo: do we need this ?
if isinstance(dump, np.ndarray):
self.alignment_heads = mx.array(dump)
elif isinstance(dump, bytes):