mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Raise an exception in the rope op if input is integer (#1884)
This commit is contained in:
parent
1a2cb72030
commit
78ba24c37d
@ -348,6 +348,11 @@ array rope(
|
||||
<< x.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!issubdtype(x.dtype(), floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.size() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
|
||||
|
@ -158,7 +158,17 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32])
|
||||
|
||||
# Test raises with integer inputs
|
||||
dims, _, base, scale, offset, traditional = defaults
|
||||
x = (mx.random.uniform(shape=(2, T, dims)) * 10).astype(mx.int32)
|
||||
with self.assertRaises(ValueError):
|
||||
y = mx.fast.rope(
|
||||
x, dims, traditional=traditional, base=base, scale=scale, offset=offset
|
||||
)
|
||||
|
||||
def test_rope_with_freqs(self):
|
||||
mx.random.seed(0)
|
||||
|
||||
# Check throws
|
||||
T = 4
|
||||
dims = 8
|
||||
|
Loading…
Reference in New Issue
Block a user