mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
Added RoPE scaling test
This commit is contained in:
parent
a415bac4f7
commit
0e5ae1d1bf
@ -556,7 +556,7 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
)
|
||||
|
||||
def test_rope(self):
|
||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
|
||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
|
||||
rope = nn.RoPE(4, **kwargs)
|
||||
shape = (1, 3, 4)
|
||||
x = mx.random.uniform(shape=shape)
|
||||
|
Loading…
Reference in New Issue
Block a user