From 78ba24c37d99c293f2374afbd027d80847d0ed70 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 19 Feb 2025 14:43:39 -0800 Subject: [PATCH] Raise an exception in the rope op if input is integer (#1884) --- mlx/fast.cpp | 5 +++++ python/tests/test_fast.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 8a60322db..fadf594d0 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -348,6 +348,11 @@ array rope( << x.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } + if (!issubdtype(x.dtype(), floating)) { + std::ostringstream msg; + msg << "[rope] Input must be a floating type but got " << x.dtype() << "."; + throw std::invalid_argument(msg.str()); + } if (offset.size() != 1) { std::ostringstream msg; msg << "[rope] offset must be a scalar but has shape " << offset.shape() diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index fd9a87c9c..2aa8b067c 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -158,7 +158,17 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32]) + # Test raises with integer inputs + dims, _, base, scale, offset, traditional = defaults + x = (mx.random.uniform(shape=(2, T, dims)) * 10).astype(mx.int32) + with self.assertRaises(ValueError): + y = mx.fast.rope( + x, dims, traditional=traditional, base=base, scale=scale, offset=offset + ) + def test_rope_with_freqs(self): + mx.random.seed(0) + # Check throws T = 4 dims = 8