mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Fix conversion + inference errors. - Mistral (#176)
* Fix conversion + inference errors. * wire rope_theta throuugh to nn.RoPE --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
7ae445f6c7
commit
0eaa323c10
@@ -23,6 +23,7 @@ class ModelArgs:
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
rope_theta: float = 10000
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
@@ -55,7 +56,7 @@ class Attention(nn.Module):
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = nn.RoPE(args.head_dim, traditional=True)
|
||||
self.rope = nn.RoPE(args.head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -274,8 +275,8 @@ if __name__ == "__main__":
|
||||
for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
if ntoks == 0:
|
||||
toc = time.time()
|
||||
mx.eval(tokens)
|
||||
toc = time.time()
|
||||
prompt_tps = prompt.size / (toc - tic)
|
||||
tic = time.time()
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
mlx
|
||||
mlx>=0.0.6
|
||||
sentencepiece
|
||||
torch
|
||||
numpy
|
||||
|
Reference in New Issue
Block a user