No reshape rope (#838)

* no reshape rope

* no reshape rope
This commit is contained in:
Awni Hannun
2024-03-18 17:03:07 -07:00
committed by GitHub
parent eaba55c9bf
commit 16546c70d8
3 changed files with 47 additions and 24 deletions

View File

@@ -44,9 +44,7 @@ class RoPE(Module):
return f"{self.dims}, traditional={self.traditional}"
def __call__(self, x, offset: int = 0):
shape = x.shape
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
x = mx.fast.rope(
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
@@ -54,7 +52,6 @@ class RoPE(Module):
scale=self.scale,
offset=offset,
)
return mx.reshape(x, shape)
class SinusoidalPositionalEncoding(Module):