Fix conversion + inference errors. - Mistral (#176)

* Fix conversion + inference errors.

* wire rope_theta through to nn.RoPE

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Vaibhav Srivastav
Date: 2023-12-23 03:40:25 +05:30
Committed by: GitHub
Parent: 7ae445f6c7
Commit: 0eaa323c10
4 changed files with 6 additions and 26 deletions


@@ -41,27 +41,6 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
-class RoPE(nn.RoPE):
-    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
-        super().__init__(dims, traditional)
-        self.base = base
-
-    def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
-        )
-
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return mx.reshape(rx, shape)
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -78,7 +57,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(
+        self.rope = nn.RoPE(
             args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
         )
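
The deleted subclass existed only to thread a non-default base (the model's rope_theta) into the rotary embedding; since mlx's nn.RoPE constructor accepts base directly, the wrapper is redundant. A minimal sketch of the simplified usage, where head_dim=128 and base=10000.0 are illustrative values (real models pull these from ModelArgs, e.g. Mistral's rope_theta):

```python
import mlx.core as mx
import mlx.nn as nn

# nn.RoPE takes the rotary base in its constructor, so no subclass is needed.
# head_dim=128 and base=10000.0 are illustrative; real values come from ModelArgs.
rope = nn.RoPE(128, traditional=False, base=10000.0)

# (batch * n_heads, seq_len, head_dim), the shape the old wrapper reshaped to.
x = mx.random.normal((4, 16, 128))
y = rope(x, offset=0)  # offset > 0 continues positions past a KV cache
print(y.shape)         # (4, 16, 128)
```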