From ec1176352746e66152daa8558779e2c59ab7a51e Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Dec 2023 21:45:25 -0800
Subject: [PATCH] fix RoPE bug + minor updates

---
 mixtral/mixtral.py | 36 ++++++++++++++++++++++++++++++------
 1 file changed, 30 insertions(+), 6 deletions(-)

diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py
index 59848219..16a9eec8 100644
--- a/mixtral/mixtral.py
+++ b/mixtral/mixtral.py
@@ -41,6 +41,26 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
+class RoPE(nn.RoPE):
+    def __init__(self, dims: int, traditional: bool = False):
+        super().__init__(dims, traditional)
+
+    def __call__(self, x, offset: int = 0):
+        shape = x.shape
+        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+        N = x.shape[1] + offset
+        costheta, sintheta = RoPE.create_cos_sin_theta(
+            N, self.dims, offset=offset, base=1000000, dtype=x.dtype
+        )
+
+        rope = (
+            self._compute_traditional_rope if self.traditional else self._compute_rope
+        )
+        rx = rope(costheta, sintheta, x)
+
+        return mx.reshape(rx, shape)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -57,7 +77,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = RoPE(args.head_dim, traditional=True)
 
     def __call__(
         self,
@@ -126,7 +146,10 @@ class MOEFeedForward(nn.Module):
 
         gates = self.gate(x)
         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
-        scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
+        scores = mx.softmax(
+            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
+            axis=-1,
+        ).astype(gates.dtype)
 
         y = []
         for xt, st, it in zip(x, scores, inds.tolist()):
@@ -182,8 +205,9 @@ class Mixtral(nn.Module):
         h = self.tok_embeddings(inputs)
 
         mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+        T = h.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
             mask = mask.astype(h.dtype)
 
         if cache is None:
@@ -192,7 +216,7 @@ class Mixtral(nn.Module):
         for e, layer in enumerate(self.layers):
             h, cache[e] = layer(h, mask, cache[e])
 
-        return self.output(self.norm(h)), cache
+        return self.output(self.norm(h[:, T - 1 : T, :])), cache
 
 
 class Tokenizer:
@@ -278,7 +302,7 @@ if __name__ == "__main__":
         "--temp",
         help="The sampling temperature.",
         type=float,
-        default=1.0,
+        default=0.0,
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
 
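
Note (not part of the patch): the RoPE bug is the rotation base. Mixtral is
trained with a rope theta of 1e6, while nn.RoPE uses the standard default base
of 10000 (and, presumably, did not yet expose a base argument), so the RoPE
subclass above rebuilds the cos/sin tables with base=1000000. Below is a minimal
sketch of how the base changes the per-dimension rotation frequencies, assuming
only the standard RoPE formula and numpy; rope_inv_freq is an illustrative
helper, and head_dim=128 is Mixtral's head dimension.

    import numpy as np

    def rope_inv_freq(dims: int, base: float) -> np.ndarray:
        # Standard RoPE inverse frequencies: base ** (-2i / dims) for pair index i.
        return base ** (-np.arange(0, dims, 2) / dims)

    # A larger base rotates the high dimensions far more slowly, which is what
    # long-context models rely on; keeping the default base of 10000 quietly
    # degrades generations even though the code still runs.
    print(rope_inv_freq(128, 10_000.0)[-1])     # ~1.2e-4
    print(rope_inv_freq(128, 1_000_000.0)[-1])  # ~1.2e-6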