mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Added implementation for Scaled RoPE. (#261)
* Added scale for RoPE * Ran pre-commit * Added RoPE scaling test * Added docstring for scale parameter * Modified docstrings
This commit is contained in:
@@ -749,7 +749,7 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_rope(self):
|
||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
|
||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
|
||||
rope = nn.RoPE(4, **kwargs)
|
||||
shape = (1, 3, 4)
|
||||
x = mx.random.uniform(shape=shape)
|
||||
|
Reference in New Issue
Block a user