mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:22:46 +08:00
some improvements to speed up alignment computation in MLX Whisper
This commit is contained in:
parent
e2e5478da5
commit
9414429309
@ -134,9 +134,7 @@ def find_alignment(
|
|||||||
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
|
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
|
||||||
# consider only the logits associated with predicting text
|
# consider only the logits associated with predicting text
|
||||||
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
|
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
|
||||||
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
|
token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
|
||||||
sampled_logits.dtype
|
|
||||||
)
|
|
||||||
text_token_probs = mx.take_along_axis(
|
text_token_probs = mx.take_along_axis(
|
||||||
token_probs, mx.array(text_tokens)[:, None], axis=1
|
token_probs, mx.array(text_tokens)[:, None], axis=1
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
@ -147,7 +145,8 @@ def find_alignment(
|
|||||||
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
|
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
|
||||||
)
|
)
|
||||||
weights = weights[:, :, : num_frames // 2]
|
weights = weights[:, :, : num_frames // 2]
|
||||||
weights = mx.softmax(weights * qk_scale, axis=-1)
|
weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
|
||||||
|
weights = weights.astype(mx.float32)
|
||||||
mean = mx.mean(weights, axis=-2, keepdims=True)
|
mean = mx.mean(weights, axis=-2, keepdims=True)
|
||||||
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
|
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
|
||||||
weights = (weights - mean) / std
|
weights = (weights - mean) / std
|
||||||
|
@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
w = mx.softmax(qk, axis=-1, precise=True)
|
w = mx.softmax(qk, axis=-1, precise=True)
|
||||||
out = (w @ v).transpose(0, 2, 1, 3)
|
out = (w @ v).transpose(0, 2, 1, 3)
|
||||||
out = out.reshape(n_batch, n_ctx, n_state)
|
out = out.reshape(n_batch, n_ctx, n_state)
|
||||||
return out, qk.astype(mx.float32)
|
return out, qk
|
||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
@ -228,13 +228,13 @@ class Whisper(nn.Module):
|
|||||||
|
|
||||||
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
|
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
|
||||||
if isinstance(dump, np.ndarray):
|
if isinstance(dump, np.ndarray):
|
||||||
self.alignment_heads = mx.array(dump)
|
self.alignment_heads = dump
|
||||||
elif isinstance(dump, bytes):
|
elif isinstance(dump, bytes):
|
||||||
array = np.frombuffer(
|
array = np.frombuffer(
|
||||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||||
).copy()
|
).copy()
|
||||||
mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
|
mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
|
||||||
self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
|
self.alignment_heads = np.asarray(mask.nonzero()).T
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
|
f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
|
||||||
|
Loading…
Reference in New Issue
Block a user