Add logit soft capping to gemma, and fix precision issues (#857)

* Add logit soft capping to gemma, and fix precision issues

Gemma was babbling nonsense - so I figured out it was due to not having logit softcapping and precision issues causing NaNs (so I implemented the softcapping and added more float32 inference). gemma-27b-it-4bit now works flawlessly (or near-flawlessly, no sliding-window attention).

* get rid of comments

* get rid of last comments (sry lol)

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
n8programs 2024-07-02 10:52:39 -04:00 committed by GitHub
parent f212b770d8
commit 1e05aef344
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -20,6 +20,9 @@ class ModelArgs(BaseModelArgs):
num_key_value_heads: int num_key_value_heads: int
rope_theta: float = 10000 rope_theta: float = 10000
rope_traditional: bool = False rope_traditional: bool = False
attn_logit_softcapping: float = 50.0
final_logit_softcapping: float = 30.0
query_pre_attn_scalar: float = 144.0
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
@ -39,15 +42,16 @@ class Attention(nn.Module):
dim = args.hidden_size dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
self.head_dim = head_dim = args.head_dim self.head_dim = head_dim = args.head_dim
self.scale = head_dim**-0.5 self.scale = 1.0 / (args.query_pre_attn_scalar**0.5)
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.attn_logit_softcapping = args.attn_logit_softcapping
self.rope = nn.RoPE( self.rope = nn.RoPE(
head_dim, head_dim,
traditional=args.rope_traditional, traditional=args.rope_traditional,
@ -61,10 +65,7 @@ class Attention(nn.Module):
cache: Optional[Tuple[mx.array, mx.array]] = None, cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array: ) -> mx.array:
B, L, D = x.shape B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
@ -77,10 +78,25 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( queries = queries * self.scale
queries, keys, values, scale=self.scale, mask=mask
)
if self.repeats > 1:
queries = queries.reshape(
B, self.n_kv_heads, self.repeats, L, self.head_dim
)
keys = mx.expand_dims(keys, 2)
values = mx.expand_dims(values, 2)
scores = queries @ keys.swapaxes(-1, -2)
scores = mx.tanh(scores / self.attn_logit_softcapping)
scores *= self.attn_logit_softcapping
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, precise=True, axis=-1)
output = scores @ values
if self.repeats > 1:
output = output.reshape(B, self.n_heads, L, self.head_dim)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)
@ -119,9 +135,11 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None, mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None, cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array: ) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache) r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
h = x + self.post_attention_layernorm(r) h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h)) r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
mx.float32
)
out = h + self.post_feedforward_layernorm(r) out = h + self.post_feedforward_layernorm(r)
return out return out
@ -165,6 +183,7 @@ class Model(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.model_type = args.model_type self.model_type = args.model_type
self.final_logit_softcapping = args.final_logit_softcapping
self.model = GemmaModel(args) self.model = GemmaModel(args)
self.args = args self.args = args
@ -175,6 +194,8 @@ class Model(nn.Module):
): ):
out = self.model(inputs, cache) out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out) out = self.model.embed_tokens.as_linear(out)
out = mx.tanh(out / self.final_logit_softcapping)
out = out * self.final_logit_softcapping
return out return out
@property @property