higher clipping, remove non-helpful casts

This commit is contained in:
Awni Hannun 2023-12-18 14:36:07 -08:00
parent d2732a6478
commit 05a8464d78

View File

@ -125,12 +125,12 @@ class MultiHeadAttention(nn.Module):
values = mx.concatenate([value_cache, values], axis=2) values = mx.concatenate([value_cache, values], axis=2)
# Dimensions are [batch x num heads x sequence x hidden dim] # Dimensions are [batch x num heads x sequence x hidden dim]
queries = queries.astype(mx.float32) queries = queries
scores = queries @ keys.astype(mx.float32) scores = queries @ keys
if mask is not None: if mask is not None:
scores = scores + mask.astype(scores.dtype) scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(values.dtype) scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values) return self.out_proj(values_hat), (keys, values)
@ -148,7 +148,7 @@ class RMSNorm(nn.Module):
t = x.dtype t = x.dtype
if t == mx.float16: if t == mx.float16:
x = x.astype(mx.float32) x = x.astype(mx.float32)
x = mx.clip(x, a_min=-1e6, a_max=1e6) x = mx.clip(x, a_min=-1e9, a_max=1e9)
output = self._norm(x).astype(t) output = self._norm(x).astype(t)
return self.weight * output return self.weight * output
@ -295,7 +295,6 @@ class T5(nn.Module):
y, cache = self.decoder( y, cache = self.decoder(
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
) )
y = y.astype(mx.float32)
if self.tie_word_embeddings: if self.tie_word_embeddings:
y *= self.model_dim**-0.5 y *= self.model_dim**-0.5
return self.lm_head(y), cache return self.lm_head(y), cache