Fix conversion + inference errors. - Mistral (#176)

* Fix conversion + inference errors.

* wire rope_theta through to nn.RoPE

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Vaibhav Srivastav
2023-12-23 03:40:25 +05:30
committed by GitHub
parent 7ae445f6c7
commit 0eaa323c10
4 changed files with 6 additions and 26 deletions

View File

@@ -23,6 +23,7 @@ class ModelArgs:
n_kv_heads: int
norm_eps: float
vocab_size: int
rope_theta: float = 10000
class RMSNorm(nn.Module):
@@ -55,7 +56,7 @@ class Attention(nn.Module):
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True)
self.rope = nn.RoPE(args.head_dim, traditional=True, base=args.rope_theta)
def __call__(
self,
@@ -274,8 +275,8 @@ if __name__ == "__main__":
for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if ntoks == 0:
toc = time.time()
mx.eval(tokens)
toc = time.time()
prompt_tps = prompt.size / (toc - tic)
tic = time.time()