Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
fix: remove custom rope (#470)
commit 838990b33b
parent dc4f2e0a6b
@@ -40,26 +40,6 @@ class RMSNorm(nn.Module):
         return self.weight * output


-class RoPE(nn.RoPE):
-    def __init__(self, dims: int, traditional: bool = False):
-        super().__init__(dims, traditional)
-
-    def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=1000000, dtype=x.dtype
-        )
-
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return mx.reshape(rx, shape)
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -76,7 +56,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(args.head_dim, traditional=True)
+        self.rope = nn.RoPE(args.head_dim, traditional=True, base=1000000)

     def __call__(
         self,
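For context: the removed subclass existed only to pass a non-default rotary base (1000000) through to the cos/sin computation. Since MLX's built-in nn.RoPE accepts base as a constructor argument, the custom class is redundant and the attention layer can construct nn.RoPE directly, as the hunk above shows. Below is a minimal usage sketch of that built-in layer; the shapes and values are illustrative and not taken from the repository.

    import mlx.core as mx
    import mlx.nn as nn

    head_dim = 64
    rope = nn.RoPE(head_dim, traditional=True, base=1000000)

    # Illustrative query tensor shaped (batch, n_heads, seq_len, head_dim).
    queries = mx.random.normal((1, 8, 16, head_dim))

    # `offset` shifts the position indices, e.g. when decoding with a KV cache.
    out = rope(queries, offset=0)
    print(out.shape)  # (1, 8, 16, 64)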