fix RoPE bug + minor updates

Awni Hannun 2023-12-14 21:45:25 -08:00
parent a3ecda22fe
commit ec11763527


@@ -41,6 +41,26 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
+class RoPE(nn.RoPE):
+    def __init__(self, dims: int, traditional: bool = False):
+        super().__init__(dims, traditional)
+
+    def __call__(self, x, offset: int = 0):
+        shape = x.shape
+        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+        N = x.shape[1] + offset
+        costheta, sintheta = RoPE.create_cos_sin_theta(
+            N, self.dims, offset=offset, base=1000000, dtype=x.dtype
+        )
+
+        rope = (
+            self._compute_traditional_rope if self.traditional else self._compute_rope
+        )
+        rx = rope(costheta, sintheta, x)
+
+        return mx.reshape(rx, shape)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
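Note (not part of the diff): the RoPE subclass above exists so the rotary embedding is built with base=1000000 (Mixtral's rope_theta) rather than the smaller default base nn.RoPE would otherwise use, presumably because nn.RoPE did not yet expose a base argument. A rough standalone sketch of what the base controls, under the standard RoPE schedule theta_i = base**(-2i/dims); the head_dim value here is an assumption for illustration:

import mlx.core as mx

def rope_inv_freq(dims: int, base: float) -> mx.array:
    # One inverse frequency per pair of dimensions; a larger base means the
    # high dimensions rotate more slowly, i.e. a longer usable context.
    return mx.power(base, -mx.arange(0, dims, 2, dtype=mx.float32) / dims)

head_dim = 128  # illustrative head dimension, not taken from the diff
print(rope_inv_freq(head_dim, 10000.0)[-1])    # a common default base (assumption)
print(rope_inv_freq(head_dim, 1000000.0)[-1])  # the base hard-coded above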
@@ -57,7 +77,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = RoPE(args.head_dim, traditional=True)
 
     def __call__(
         self,
@@ -126,7 +146,10 @@ class MOEFeedForward(nn.Module):
 
         gates = self.gate(x)
         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
-        scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
+        scores = mx.softmax(
+            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
+            axis=-1,
+        ).astype(gates.dtype)
 
         y = []
         for xt, st, it in zip(x, scores, inds.tolist()):
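Note (not part of the diff): the router change computes the softmax over the selected gate logits in float32 and casts back, a standard numerical-stability pattern when the model runs in half precision. A minimal sketch of the pattern (the values are made up):

import mlx.core as mx

gates = mx.array([[8.0, 7.5, -2.0, 1.0]], dtype=mx.float16)
# Upcast before softmax, downcast after, so the expert weights are computed
# in full float32 precision even though the activations are float16.
scores = mx.softmax(gates.astype(mx.float32), axis=-1).astype(gates.dtype)
print(scores)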
@@ -182,8 +205,9 @@ class Mixtral(nn.Module):
         h = self.tok_embeddings(inputs)
 
         mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+        T = h.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
             mask = mask.astype(h.dtype)
 
         if cache is None:
@@ -192,7 +216,7 @@ class Mixtral(nn.Module):
         for e, layer in enumerate(self.layers):
             h, cache[e] = layer(h, mask, cache[e])
 
-        return self.output(self.norm(h)), cache
+        return self.output(self.norm(h[:, T - 1 : T, :])), cache
 
 
 class Tokenizer:
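Note (not part of the diff): with T saved earlier in the forward pass, the return line now projects only the final position through the norm and output head, so a long prompt yields logits of shape (batch, 1, vocab) instead of (batch, T, vocab). A tiny shape check (the sizes are illustrative, not from the diff):

import mlx.core as mx

h = mx.zeros((1, 7, 4096))        # (batch, T, dim), illustrative sizes
T = h.shape[1]
print(h[:, T - 1 : T, :].shape)   # (1, 1, 4096): only the last token is kept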
@@ -278,7 +302,7 @@ if __name__ == "__main__":
         "--temp",
         help="The sampling temperature.",
         type=float,
-        default=1.0,
+        default=0.0,
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
 
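Note (not part of the diff): a default temperature of 0.0 only works if the sampling helper treats temp == 0 as greedy decoding. That helper is not shown in this diff; the usual convention looks roughly like the sketch below (an assumption, not code from this file):

import mlx.core as mx

def sample(logits: mx.array, temp: float) -> mx.array:
    if temp == 0:
        # Greedy: deterministically take the most likely token.
        return mx.argmax(logits, axis=-1)
    # Otherwise sample from the temperature-scaled distribution.
    return mx.random.categorical(logits * (1 / temp))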