Cast qk to fp32

bofenghuang
2024-01-06 18:59:10 +01:00
parent 57111100a2
commit 9cdf6388c7
2 changed files with 3 additions and 2 deletions

@@ -139,7 +139,7 @@ def find_alignment(
     # heads * tokens * frames
     weights = mx.stack([cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads])
     weights = weights[:, :, : num_frames // 2]
-    weights = mx.softmax((weights * qk_scale).astype(mx.float32), axis=-1).astype(weights.dtype)
+    weights = mx.softmax(weights * qk_scale, axis=-1)
     mean = mx.mean(weights, axis=-2, keepdims=True)
     std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
     weights = (weights - mean) / std
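
The float32 cast dropped here looks redundant once the attention layer in the second file returns qk already cast to float32 (see the hunk below): the stacked cross-attention weights then arrive in full precision before the softmax. A minimal sketch of that invariant, with an illustrative shape and a placeholder qk_scale that are not taken from the repo:

import mlx.core as mx

# Hypothetical cross-attention weights as the patched attention layer would
# return them: already float32 (shape and scale are illustrative only).
cross_qk = mx.zeros((8, 10, 1500), dtype=mx.float32)
qk_scale = 0.25

weights = mx.softmax(cross_qk * qk_scale, axis=-1)
print(weights.dtype)  # mlx.core.float32: the softmax already ran in full precision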

@@ -85,8 +85,9 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.astype(mx.float32)
-        w = mx.softmax(qk.astype(mx.float32), axis=-1).astype(q.dtype)
+        w = mx.softmax(qk, axis=-1).astype(q.dtype)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
         return out, qk
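
Taken together, the patched attention path computes the softmax on float32 scores and returns that float32 qk to the caller, while the attention output itself is cast back to the input dtype. A minimal self-contained sketch of that behavior, using a simplified scaled_attention helper (a hypothetical stand-in for the repo's MultiHeadAttention) and toy shapes:

import mlx.core as mx

def scaled_attention(q, k, v, mask=None):
    # Sketch of the patched logic: cast the attention scores to float32 so the
    # softmax and the returned qk stay in full precision, while the attention
    # output follows the input dtype.
    n_ctx = q.shape[1]
    qk = q @ k.transpose(0, 2, 1)
    if mask is not None:
        qk = qk + mask[:n_ctx, :n_ctx]
    qk = qk.astype(mx.float32)
    w = mx.softmax(qk, axis=-1).astype(q.dtype)
    return w @ v, qk

# With half-precision inputs the output stays float16 while qk comes back float32.
q = mx.random.normal((1, 4, 8)).astype(mx.float16)
k = mx.random.normal((1, 4, 8)).astype(mx.float16)
v = mx.random.normal((1, 4, 8)).astype(mx.float16)
out, qk = scaled_attention(q, k, v)
print(out.dtype, qk.dtype)  # float16 float32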