mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Cast qk to fp32
This commit is contained in:
@@ -139,7 +139,7 @@ def find_alignment(
|
||||
# heads * tokens * frames
|
||||
weights = mx.stack([cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads])
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = mx.softmax((weights * qk_scale).astype(mx.float32), axis=-1).astype(weights.dtype)
|
||||
weights = mx.softmax(weights * qk_scale, axis=-1)
|
||||
mean = mx.mean(weights, axis=-2, keepdims=True)
|
||||
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
|
||||
weights = (weights - mean) / std
|
||||
|
@@ -85,8 +85,9 @@ class MultiHeadAttention(nn.Module):
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.astype(mx.float32)
|
||||
|
||||
w = mx.softmax(qk.astype(mx.float32), axis=-1).astype(q.dtype)
|
||||
w = mx.softmax(qk, axis=-1).astype(q.dtype)
|
||||
out = (w @ v).transpose(0, 2, 1, 3)
|
||||
out = out.reshape(n_batch, n_ctx, n_state)
|
||||
return out, qk
|
||||
|
Reference in New Issue
Block a user