mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
higher clipping, remove non-helpful casts
This commit is contained in:
parent
d2732a6478
commit
05a8464d78
9
t5/t5.py
9
t5/t5.py
@ -125,12 +125,12 @@ class MultiHeadAttention(nn.Module):
|
|||||||
values = mx.concatenate([value_cache, values], axis=2)
|
values = mx.concatenate([value_cache, values], axis=2)
|
||||||
|
|
||||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
queries = queries.astype(mx.float32)
|
queries = queries
|
||||||
scores = queries @ keys.astype(mx.float32)
|
scores = queries @ keys
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask.astype(scores.dtype)
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
|
||||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(values.dtype)
|
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.out_proj(values_hat), (keys, values)
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ class RMSNorm(nn.Module):
|
|||||||
t = x.dtype
|
t = x.dtype
|
||||||
if t == mx.float16:
|
if t == mx.float16:
|
||||||
x = x.astype(mx.float32)
|
x = x.astype(mx.float32)
|
||||||
x = mx.clip(x, a_min=-1e6, a_max=1e6)
|
x = mx.clip(x, a_min=-1e9, a_max=1e9)
|
||||||
output = self._norm(x).astype(t)
|
output = self._norm(x).astype(t)
|
||||||
return self.weight * output
|
return self.weight * output
|
||||||
|
|
||||||
@ -295,7 +295,6 @@ class T5(nn.Module):
|
|||||||
y, cache = self.decoder(
|
y, cache = self.decoder(
|
||||||
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||||
)
|
)
|
||||||
y = y.astype(mx.float32)
|
|
||||||
if self.tie_word_embeddings:
|
if self.tie_word_embeddings:
|
||||||
y *= self.model_dim**-0.5
|
y *= self.model_dim**-0.5
|
||||||
return self.lm_head(y), cache
|
return self.lm_head(y), cache
|
||||||
|
Loading…
Reference in New Issue
Block a user