mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	| @@ -8,7 +8,7 @@ import mlx_tests | ||||
|  | ||||
|  | ||||
| def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): | ||||
|     N = x.shape[1] + offset | ||||
|     N = x.shape[-2] + offset | ||||
|     dtype = x.dtype | ||||
|     half_D = dims // 2 | ||||
|     positions = mx.arange(offset, N, dtype=dtype) * scale | ||||
| @@ -143,6 +143,20 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|                 ) | ||||
|                 self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) | ||||
|  | ||||
|         # Test transpose into rope | ||||
|         dims, _, base, scale, offset, traditional = defaults | ||||
|         x = mx.random.uniform(shape=(1, 1, 4, dims)).swapaxes(1, 2) | ||||
|         rx = rope_orig(x, dims, traditional, base, scale, offset) | ||||
|         rx_fast = mx.fast.rope( | ||||
|             1.0 * x,  # multiply here to allow donation | ||||
|             dims, | ||||
|             traditional=traditional, | ||||
|             base=base, | ||||
|             scale=scale, | ||||
|             offset=offset, | ||||
|         ) | ||||
|         self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32]) | ||||
|  | ||||
|     def test_rope_with_freqs(self): | ||||
|         # Check throws | ||||
|         T = 4 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun