mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
@@ -44,9 +44,7 @@ class RoPE(Module):
|
||||
return f"{self.dims}, traditional={self.traditional}"
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
shape = x.shape
|
||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||
x = mx.fast.rope(
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
@@ -54,7 +52,6 @@ class RoPE(Module):
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
return mx.reshape(x, shape)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(Module):
|
||||
|
Reference in New Issue
Block a user