clamp for low precision

This commit is contained in:
Awni Hannun 2023-12-18 14:25:58 -08:00
parent fd351850e4
commit d2732a6478

View File

@ -125,11 +125,12 @@ class MultiHeadAttention(nn.Module):
values = mx.concatenate([value_cache, values], axis=2) values = mx.concatenate([value_cache, values], axis=2)
# Dimensions are [batch x num heads x sequence x hidden dim] # Dimensions are [batch x num heads x sequence x hidden dim]
scores = queries @ keys queries = queries.astype(mx.float32)
scores = queries @ keys.astype(mx.float32)
if mask is not None: if mask is not None:
scores = scores + mask.astype(scores.dtype) scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values) return self.out_proj(values_hat), (keys, values)
@ -144,7 +145,11 @@ class RMSNorm(nn.Module):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x): def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype) t = x.dtype
if t == mx.float16:
x = x.astype(mx.float32)
x = mx.clip(x, a_min=-1e6, a_max=1e6)
output = self._norm(x).astype(t)
return self.weight * output return self.weight * output
@ -290,6 +295,7 @@ class T5(nn.Module):
y, cache = self.decoder( y, cache = self.decoder(
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
) )
y = y.astype(mx.float32)
if self.tie_word_embeddings: if self.tie_word_embeddings:
y *= self.model_dim**-0.5 y *= self.model_dim**-0.5
return self.lm_head(y), cache return self.lm_head(y), cache
@ -409,7 +415,7 @@ if __name__ == "__main__":
mx.random.seed(args.seed) mx.random.seed(args.seed)
model, tokenizer = load_model(args.model) model, tokenizer = load_model(args.model, args.dtype)
if args.encode_only: if args.encode_only:
print("[INFO] Encoding with T5...", flush=True) print("[INFO] Encoding with T5...", flush=True)