From 0eaa323c108d291d9f783af7e906068336ff4cfb Mon Sep 17 00:00:00 2001
From: Vaibhav Srivastav
Date: Sat, 23 Dec 2023 03:40:25 +0530
Subject: [PATCH] Fix conversion + inference errors. - Mistral (#176)

* Fix conversion + inference errors.

* wire rope_theta through to nn.RoPE

---------

Co-authored-by: Awni Hannun
---
 llms/llama/llama.py           | 23 +----------------------
 llms/llama/requirements.txt   |  2 +-
 llms/mistral/mistral.py       |  5 +++--
 llms/mistral/requirements.txt |  2 +-
 4 files changed, 6 insertions(+), 26 deletions(-)

diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index 6a7352f3..74198c89 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -41,27 +41,6 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
-class RoPE(nn.RoPE):
-    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
-        super().__init__(dims, traditional)
-        self.base = base
-
-    def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
-        )
-
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return mx.reshape(rx, shape)
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -78,7 +57,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(
+        self.rope = nn.RoPE(
             args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
         )
 
diff --git a/llms/llama/requirements.txt b/llms/llama/requirements.txt
index 7111f1d4..f9a728c4 100644
--- a/llms/llama/requirements.txt
+++ b/llms/llama/requirements.txt
@@ -1,3 +1,3 @@
-mlx
+mlx>=0.0.6
 sentencepiece
 torch
diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py
index 688360f2..105a7988 100644
--- a/llms/mistral/mistral.py
+++ b/llms/mistral/mistral.py
@@ -23,6 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
+    rope_theta: float = 10000
 
 
 class RMSNorm(nn.Module):
@@ -55,7 +56,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = nn.RoPE(args.head_dim, traditional=True, base=args.rope_theta)
 
     def __call__(
         self,
@@ -274,8 +275,8 @@ if __name__ == "__main__":
     for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
         if ntoks == 0:
-            toc = time.time()
             mx.eval(tokens)
+            toc = time.time()
             prompt_tps = prompt.size / (toc - tic)
             tic = time.time()
 
diff --git a/llms/mistral/requirements.txt b/llms/mistral/requirements.txt
index d775b88f..755af473 100644
--- a/llms/mistral/requirements.txt
+++ b/llms/mistral/requirements.txt
@@ -1,4 +1,4 @@
-mlx
+mlx>=0.0.6
 sentencepiece
 torch
 numpy
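
Note (not part of the patch): the llama.py hunks drop the hand-rolled RoPE subclass because, with the mlx>=0.0.6 requirement the patch introduces, nn.RoPE accepts the rotary base directly, so rope_theta can be wired straight through. Below is a minimal sketch of that wiring; head_dim, rope_theta, and the tensor shape are illustrative placeholders, not values taken from the patch.

import mlx.core as mx
import mlx.nn as nn

# Illustrative values only; real models read these from their checkpoint
# config (ModelArgs now carries rope_theta for Mistral).
head_dim = 128
rope_theta = 10000.0

# Mirrors the patched Mistral Attention: pass the base straight to nn.RoPE.
rope = nn.RoPE(head_dim, traditional=True, base=rope_theta)

# Apply to a dummy (batch, n_heads, seq_len, head_dim) array; offset would be
# the current KV-cache length during incremental decoding.
x = mx.random.normal((1, 8, 16, head_dim))
y = rope(x, offset=0)
print(y.shape)  # same shape as x; rotation is applied along the last axis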