mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Ensure linspace always contains start and stop (#1883)
This commit is contained in:
parent
344a29506e
commit
1a2cb72030
14
mlx/ops.cpp
14
mlx/ops.cpp
@ -230,16 +230,16 @@ array linspace(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (num == 1) {
|
if (num == 1) {
|
||||||
return astype(array({start}), dtype, to_stream(s));
|
return astype(array({start}), dtype, s);
|
||||||
}
|
}
|
||||||
array sequence = arange(0, num, float32, to_stream(s));
|
array t = divide(arange(0, num, float32, s), array(num - 1, float32), s);
|
||||||
float step = (stop - start) / (num - 1);
|
array t_bar = subtract(array(1, float32), t, s);
|
||||||
return astype(
|
return astype(
|
||||||
add(multiply(sequence, array(step), to_stream(s)),
|
add(multiply(t_bar, array(start, float32), s),
|
||||||
array(start),
|
multiply(t, array(stop, float32), s),
|
||||||
to_stream(s)),
|
s),
|
||||||
dtype,
|
dtype,
|
||||||
to_stream(s));
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
|
@ -2189,6 +2189,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected = mx.array(np.linspace(1, 10, 1))
|
expected = mx.array(np.linspace(1, 10, 1))
|
||||||
self.assertEqualArray(d, expected)
|
self.assertEqualArray(d, expected)
|
||||||
|
|
||||||
|
# Ensure that the start and stop are always the ones provided
|
||||||
|
ranges = mx.random.normal((16, 2)).tolist()
|
||||||
|
nums = (2 + mx.random.uniform(shape=(16,)) * 10).astype(mx.uint32).tolist()
|
||||||
|
for (a, b), n in zip(ranges, nums):
|
||||||
|
d = mx.linspace(a, b, n).tolist()
|
||||||
|
self.assertEqual(d[0], a)
|
||||||
|
self.assertEqual(d[-1], b)
|
||||||
|
|
||||||
def test_repeat(self):
|
def test_repeat(self):
|
||||||
# Setup data for the tests
|
# Setup data for the tests
|
||||||
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
||||||
|
Loading…
Reference in New Issue
Block a user