mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
clamp for low precision
This commit is contained in:
parent
fd351850e4
commit
d2732a6478
14
t5/t5.py
14
t5/t5.py
@ -125,11 +125,12 @@ class MultiHeadAttention(nn.Module):
|
|||||||
values = mx.concatenate([value_cache, values], axis=2)
|
values = mx.concatenate([value_cache, values], axis=2)
|
||||||
|
|
||||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
scores = queries @ keys
|
queries = queries.astype(mx.float32)
|
||||||
|
scores = queries @ keys.astype(mx.float32)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask.astype(scores.dtype)
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
|
||||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(values.dtype)
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.out_proj(values_hat), (keys, values)
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
@ -144,7 +145,11 @@ class RMSNorm(nn.Module):
|
|||||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
t = x.dtype
|
||||||
|
if t == mx.float16:
|
||||||
|
x = x.astype(mx.float32)
|
||||||
|
x = mx.clip(x, a_min=-1e6, a_max=1e6)
|
||||||
|
output = self._norm(x).astype(t)
|
||||||
return self.weight * output
|
return self.weight * output
|
||||||
|
|
||||||
|
|
||||||
@ -290,6 +295,7 @@ class T5(nn.Module):
|
|||||||
y, cache = self.decoder(
|
y, cache = self.decoder(
|
||||||
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||||
)
|
)
|
||||||
|
y = y.astype(mx.float32)
|
||||||
if self.tie_word_embeddings:
|
if self.tie_word_embeddings:
|
||||||
y *= self.model_dim**-0.5
|
y *= self.model_dim**-0.5
|
||||||
return self.lm_head(y), cache
|
return self.lm_head(y), cache
|
||||||
@ -409,7 +415,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
model, tokenizer = load_model(args.model)
|
model, tokenizer = load_model(args.model, args.dtype)
|
||||||
|
|
||||||
if args.encode_only:
|
if args.encode_only:
|
||||||
print("[INFO] Encoding with T5...", flush=True)
|
print("[INFO] Encoding with T5...", flush=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user