Raise an exception in the rope op if input is integer (#1884)

This commit is contained in:
Angelos Katharopoulos 2025-02-19 14:43:39 -08:00 committed by GitHub
parent 1a2cb72030
commit 78ba24c37d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 0 deletions

View File

@ -348,6 +348,11 @@ array rope(
<< x.ndim() << " dimensions."; << x.ndim() << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (!issubdtype(x.dtype(), floating)) {
std::ostringstream msg;
msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (offset.size() != 1) { if (offset.size() != 1) {
std::ostringstream msg; std::ostringstream msg;
msg << "[rope] offset must be a scalar but has shape " << offset.shape() msg << "[rope] offset must be a scalar but has shape " << offset.shape()

View File

@ -158,7 +158,17 @@ class TestFast(mlx_tests.MLXTestCase):
) )
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32]) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32])
# Test raises with integer inputs
dims, _, base, scale, offset, traditional = defaults
x = (mx.random.uniform(shape=(2, T, dims)) * 10).astype(mx.int32)
with self.assertRaises(ValueError):
y = mx.fast.rope(
x, dims, traditional=traditional, base=base, scale=scale, offset=offset
)
def test_rope_with_freqs(self): def test_rope_with_freqs(self):
mx.random.seed(0)
# Check throws # Check throws
T = 4 T = 4
dims = 8 dims = 8