fix RoPE bug + minor updates
commit ec11763527
parent a3ecda22fe
@@ -41,6 +41,26 @@ class RMSNorm(nn.Module):
         return self.weight * output


+class RoPE(nn.RoPE):
+    def __init__(self, dims: int, traditional: bool = False):
+        super().__init__(dims, traditional)
+
+    def __call__(self, x, offset: int = 0):
+        shape = x.shape
+        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+        N = x.shape[1] + offset
+        costheta, sintheta = RoPE.create_cos_sin_theta(
+            N, self.dims, offset=offset, base=1000000, dtype=x.dtype
+        )
+
+        rope = (
+            self._compute_traditional_rope if self.traditional else self._compute_rope
+        )
+        rx = rope(costheta, sintheta, x)
+
+        return mx.reshape(rx, shape)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
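The new RoPE subclass exists so the rotary angles are computed with base=1000000 (Mixtral's rope_theta) instead of the 10000 default used by nn.RoPE; the Attention hunk below then swaps in this subclass so the larger base actually takes effect. A minimal sketch of how the base enters the rotation angles, assuming the standard RoPE frequency formula; rope_angles is an illustrative helper, not part of this file:

import math
import mlx.core as mx

def rope_angles(num_positions: int, dims: int, base: float = 1000000):
    # Inverse frequencies base ** (-2i / dims) for i = 0 .. dims/2 - 1.
    # A larger base stretches the longest wavelengths, which long-context
    # models such as Mixtral depend on.
    inv_freq = mx.exp(mx.arange(0.0, dims, 2.0) * (-math.log(base) / dims))
    positions = mx.arange(0.0, num_positions)
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freq, (1, -1))
    return mx.cos(theta), mx.sin(theta)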
@@ -57,7 +77,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = RoPE(args.head_dim, traditional=True)

     def __call__(
         self,
@@ -126,7 +146,10 @@ class MOEFeedForward(nn.Module):

         gates = self.gate(x)
         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
-        scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
+        scores = mx.softmax(
+            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
+            axis=-1,
+        ).astype(gates.dtype)

         y = []
         for xt, st, it in zip(x, scores, inds.tolist()):
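The router scores are now softmaxed in float32 and cast back to the working dtype: exponentiating and normalizing half-precision gate logits can wash out the small differences between experts. A toy sketch of the same upcast-softmax-downcast pattern, with made-up logits:

import mlx.core as mx

gates = mx.array([[5.2, 4.9, -1.0, 0.3]], dtype=mx.float16)  # toy router logits
# Upcast before the softmax, then return to the original dtype.
scores = mx.softmax(gates.astype(mx.float32), axis=-1).astype(gates.dtype)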
@@ -182,8 +205,9 @@ class Mixtral(nn.Module):
         h = self.tok_embeddings(inputs)

         mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+        T = h.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
             mask = mask.astype(h.dtype)

         if cache is None:
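Naming the sequence length T keeps the mask logic readable, and the mask is only built when more than one token is processed at once, since a single decode step attends to the whole cache anyway. A small sketch of the additive causal mask this produces, using the same mlx helper as the hunk above:

import mlx.core as mx
import mlx.nn as nn

T = 4
# 0 on and below the diagonal, -inf above: position t cannot attend to t+1, t+2, ...
mask = nn.MultiHeadAttention.create_additive_causal_mask(T).astype(mx.float16)
print(mask)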
@@ -192,7 +216,7 @@ class Mixtral(nn.Module):
         for e, layer in enumerate(self.layers):
             h, cache[e] = layer(h, mask, cache[e])

-        return self.output(self.norm(h)), cache
+        return self.output(self.norm(h[:, T - 1 : T, :])), cache


 class Tokenizer:
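Because generation only ever samples from the logits of the most recent position, the model now slices h[:, T - 1 : T, :] before the final norm and output projection rather than projecting every prompt token to the full vocabulary. A rough sketch of the decode loop this shape supports; model, prompt, and sample are placeholders, not names taken from this script:

import mlx.core as mx

def generate(model, prompt, max_tokens, sample):
    # Prompt pass: only the last position's logits come back, shape (1, 1, vocab).
    logits, cache = model(prompt[None])
    y = sample(logits[:, -1, :])
    for _ in range(max_tokens):
        yield y
        # Decode passes feed one token at a time and reuse the key/value cache.
        logits, cache = model(y[:, None], cache)
        y = sample(logits[:, -1, :])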
@@ -278,7 +302,7 @@ if __name__ == "__main__":
         "--temp",
         help="The sampling temperature.",
         type=float,
-        default=1.0,
+        default=0.0,
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

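With the default temperature now 0.0, sampling follows the usual convention of greedy decoding. A minimal sketch of that convention; this sample helper is illustrative and may not match the one in the script:

import mlx.core as mx

def sample(logits, temp: float = 0.0):
    # temp == 0 -> pick the argmax; otherwise sample from the tempered distribution.
    if temp == 0:
        return mx.argmax(logits, axis=-1)
    return mx.random.categorical(logits * (1 / temp))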