Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-25 01:41:19 +08:00
Fix conversion + inference errors. - Mistral (#176)
* Fix conversion + inference errors.

* wire rope_theta through to nn.RoPE

---------

Co-authored-by: Awni Hannun <awni@apple.com>
parent 7ae445f6c7
commit 0eaa323c10
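To make the commit message concrete, here is a minimal sketch of the end state the diff below produces: the RoPE base (rope_theta) travels from the model arguments straight into mlx's built-in nn.RoPE instead of a hand-rolled subclass. Names are taken from the diff; the shapes and values are illustrative only, not the example's real configuration.

import mlx.core as mx
import mlx.nn as nn

# Illustrative values; the real examples read these from the model config.
head_dim = 128
rope_theta = 10000.0

# After this commit the base is passed directly to the built-in module.
rope = nn.RoPE(head_dim, traditional=False, base=rope_theta)

# Apply to a dummy (batch, heads, seq_len, head_dim) tensor.
x = mx.random.normal((1, 8, 16, head_dim))
y = rope(x)  # y has the same shape as x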
@@ -41,27 +41,6 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
-class RoPE(nn.RoPE):
-    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
-        super().__init__(dims, traditional)
-        self.base = base
-
-    def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
-        )
-
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return mx.reshape(rx, shape)
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
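The subclass removed above existed only so that a non-default base could reach the cos/sin tables; with nn.RoPE accepting base directly, it is redundant. For reference, a rough sketch of the standard RoPE angle table such a helper computes (illustrative only, not mlx's exact implementation):

import numpy as np

def rope_cos_sin(n_positions, dims, base=10000.0, offset=0):
    # Standard RoPE frequencies: one per pair of feature dimensions.
    inv_freq = base ** (-np.arange(0, dims, 2) / dims)
    positions = np.arange(offset, n_positions)
    theta = np.outer(positions, inv_freq)  # (n_positions, dims // 2)
    return np.cos(theta), np.sin(theta)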
@@ -78,7 +57,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(
+        self.rope = nn.RoPE(
             args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
         )
 
@@ -1,3 +1,3 @@
-mlx
+mlx>=0.0.6
 sentencepiece
 torch
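The pin to mlx>=0.0.6 is presumably what guarantees the base keyword on nn.RoPE used above. A quick, standard-library-only way to check which mlx version is installed:

from importlib.metadata import version

# Prints the installed mlx package version, e.g. "0.0.6".
print(version("mlx"))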
@@ -23,6 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
+    rope_theta: float = 10000
 
 
 class RMSNorm(nn.Module):
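Because rope_theta is added with a default of 10000, converted configs that lack the field still load. A minimal sketch of how the arguments might be built from a params dict (a cut-down dataclass with only the fields visible in this diff; the real example loads its parameters from a JSON file, and the dict below is hypothetical):

from dataclasses import dataclass

@dataclass
class ModelArgs:
    dim: int
    n_heads: int
    head_dim: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float = 10000

# Hypothetical config without rope_theta: the default kicks in.
config = {
    "dim": 4096,
    "n_heads": 32,
    "head_dim": 128,
    "n_kv_heads": 8,
    "norm_eps": 1e-5,
    "vocab_size": 32000,
}
args = ModelArgs(**config)
print(args.rope_theta)  # 10000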
@@ -55,7 +56,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = nn.RoPE(args.head_dim, traditional=True, base=args.rope_theta)
 
     def __call__(
         self,
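In the attention forward pass, the same rope module is applied to queries and keys, with an offset equal to the number of cached positions during incremental decoding (the offset argument matches the signature of the subclass removed earlier in this diff). A rough sketch with illustrative shapes, not the example's exact code:

import mlx.core as mx
import mlx.nn as nn

head_dim, n_heads = 128, 8
rope = nn.RoPE(head_dim, traditional=True, base=10000.0)

# Prompt step: rotate all positions starting at 0.
queries = mx.random.normal((1, n_heads, 16, head_dim))
queries = rope(queries)

# Decode step: one new token, rotated as position 16.
new_q = mx.random.normal((1, n_heads, 1, head_dim))
new_q = rope(new_q, offset=16)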
@@ -274,8 +275,8 @@ if __name__ == "__main__":
     for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
         if ntoks == 0:
-            toc = time.time()
             mx.eval(tokens)
+            toc = time.time()
             prompt_tps = prompt.size / (toc - tic)
             tic = time.time()
 
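The timing fix above follows from MLX's lazy evaluation: the prompt computation only actually runs when mx.eval(tokens) forces it, so toc must be read after that call or the measured prompt time is close to zero. A minimal sketch of the pattern, with a matrix product standing in for prompt processing:

import time
import mlx.core as mx

tic = time.time()
# Building the computation is cheap and lazy; nothing has run yet.
x = mx.random.normal((2048, 2048))
y = x @ x

mx.eval(y)         # the computation actually happens here
toc = time.time()  # so the clock must be read after mx.eval
print(f"elapsed: {toc - tic:.3f}s")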
@@ -1,4 +1,4 @@
-mlx
+mlx>=0.0.6
 sentencepiece
 torch
 numpy