mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix rope fallback to not upcast (#1797)
* fix rope fallback to not upcast * Update mlx/fast.cpp Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -18,7 +18,7 @@ def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): | ||||
|             -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) | ||||
|         ) | ||||
|     else: | ||||
|         inv_freqs = 1 / freqs | ||||
|         inv_freqs = (1 / freqs).astype(x.dtype) | ||||
|     theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1)) | ||||
|     costheta, sintheta = mx.cos(theta), mx.sin(theta) | ||||
|     if traditional: | ||||
| @@ -189,17 +189,21 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         freqs = mx.random.uniform(shape=(dims // 2,)) | ||||
|  | ||||
|         rx = rope_orig(x, dims, False, None, 1.0, 0, freqs) | ||||
|         rx_fast = mx.fast.rope( | ||||
|             x, | ||||
|             dims, | ||||
|             traditional=False, | ||||
|             base=None, | ||||
|             scale=1.0, | ||||
|             offset=0, | ||||
|             freqs=freqs, | ||||
|         ) | ||||
|         self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5) | ||||
|         tolerances = {mx.float32: 1e-5, mx.float16: 1e-2} | ||||
|         for dtype in [mx.float32, mx.float16]: | ||||
|             x_ = x.astype(dtype) | ||||
|             rx = rope_orig(x_, dims, False, None, 1.0, 0, freqs) | ||||
|             rx_fast = mx.fast.rope( | ||||
|                 x_, | ||||
|                 dims, | ||||
|                 traditional=False, | ||||
|                 base=None, | ||||
|                 scale=1.0, | ||||
|                 offset=0, | ||||
|                 freqs=freqs, | ||||
|             ) | ||||
|             self.assertEqual(dtype, rx.dtype) | ||||
|             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) | ||||
|  | ||||
|         # Test single vector | ||||
|         x = mx.random.uniform(shape=(1, 1, dims)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun