[Whisper] Add word timestamps and confidence scores (#201)

* Add word timestamps and confidence scores

* Create a separate forward_with_cross_qk function

* Move multiple ops from np to mlx, clean comments

* Save alignment_heads

* Cast qk to fp32

* Add test for word-level timestamps and confidence scores

* format + readme

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
bofeng huang
2024-01-07 19:01:29 +01:00
committed by GitHub
parent 25ebd36112
commit bf9926489e
7 changed files with 398 additions and 111 deletions

@@ -199,6 +199,10 @@ def torch_to_mlx(
    mlx_model = Whisper(torch_model.dims, dtype)
    params = tree_map(lambda p: p.astype(dtype), params)
    mlx_model.update(params)
    if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
        mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
    return mlx_model
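
The hunk above passes `alignment_heads.indices().T.numpy()` to `set_alignment_heads`. Assuming `alignment_heads` is a sparse boolean mask of shape (n_layers, n_heads), as it appears to be in the reference PyTorch Whisper, that expression should yield one `[layer, head]` row per selected cross-attention head. The sketch below mimics the conversion in plain Python on a made-up mask. The helper name `mask_to_head_pairs` and the mask values are hypothetical and are not part of this commit.

```python
def mask_to_head_pairs(mask):
    """Return a [layer, head] pair for every True entry of a 2-D boolean mask.

    Illustrative stand-in for `sparse_mask.indices().T` on a torch sparse
    tensor; this helper is hypothetical and does not exist in the repository.
    """
    return [
        [layer, head]
        for layer, row in enumerate(mask)
        for head, flag in enumerate(row)
        if flag
    ]


# Hypothetical mask for a 2-layer, 3-head decoder (values made up for illustration).
mask = [
    [False, True, False],
    [True, False, True],
]

pairs = mask_to_head_pairs(mask)
print(pairs)  # one [layer, head] row per alignment head
```

Storing the heads as index pairs rather than a dense mask keeps the saved attribute small and lets the model pick out only the cross-attention weights that are used for word-level alignment.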