mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
clamp for low precision
This commit is contained in:
parent
fd351850e4
commit
d2732a6478
14
t5/t5.py
14
t5/t5.py
@ -125,11 +125,12 @@ class MultiHeadAttention(nn.Module):
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scores = queries @ keys
|
||||
queries = queries.astype(mx.float32)
|
||||
scores = queries @ keys.astype(mx.float32)
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(values.dtype)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
@ -144,7 +145,11 @@ class RMSNorm(nn.Module):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
t = x.dtype
|
||||
if t == mx.float16:
|
||||
x = x.astype(mx.float32)
|
||||
x = mx.clip(x, a_min=-1e6, a_max=1e6)
|
||||
output = self._norm(x).astype(t)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
@ -290,6 +295,7 @@ class T5(nn.Module):
|
||||
y, cache = self.decoder(
|
||||
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||
)
|
||||
y = y.astype(mx.float32)
|
||||
if self.tie_word_embeddings:
|
||||
y *= self.model_dim**-0.5
|
||||
return self.lm_head(y), cache
|
||||
@ -409,7 +415,7 @@ if __name__ == "__main__":
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model, tokenizer = load_model(args.model)
|
||||
model, tokenizer = load_model(args.model, args.dtype)
|
||||
|
||||
if args.encode_only:
|
||||
print("[INFO] Encoding with T5...", flush=True)
|
||||
|
Loading…
Reference in New Issue
Block a user