diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py
index 8a884817..bb7b5238 100644
--- a/llms/mixtral/mixtral.py
+++ b/llms/mixtral/mixtral.py
@@ -40,26 +40,6 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
-class RoPE(nn.RoPE):
-    def __init__(self, dims: int, traditional: bool = False):
-        super().__init__(dims, traditional)
-
-    def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=1000000, dtype=x.dtype
-        )
-
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return mx.reshape(rx, shape)
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -76,7 +56,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(args.head_dim, traditional=True)
+        self.rope = nn.RoPE(args.head_dim, traditional=True, base=1000000)
 
     def __call__(
         self,
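
The patch drops the hand-rolled `RoPE` subclass: its only purpose was to pass `base=1000000` (Mixtral's rotary base) to the cos/sin table, which `mlx.nn.RoPE` can now take directly as a keyword. A minimal usage sketch, assuming the `mlx` package is installed and that `nn.RoPE.__call__` accepts the same `(x, offset)` arguments as the removed subclass (both the keyword and the call signature are visible in the diff); the shapes here are illustrative, not from the repo:

    import mlx.core as mx
    import mlx.nn as nn

    # Built-in RoPE with Mixtral's base replaces the removed subclass.
    rope = nn.RoPE(128, traditional=True, base=1000000)

    # (batch, n_heads, seq_len, head_dim); RoPE rotates along the last axis
    # using positions taken from the second-to-last axis.
    x = mx.random.normal((1, 8, 16, 128))
    y = rope(x)              # positions 0..15
    y2 = rope(x, offset=16)  # continue from a KV-cache offset of 16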