Ensure linspace always contains start and stop (#1883)

This commit is contained in:
Angelos Katharopoulos 2025-02-19 13:53:20 -08:00 committed by GitHub
parent 344a29506e
commit 1a2cb72030
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 7 deletions

View File

@ -230,16 +230,16 @@ array linspace(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (num == 1) { if (num == 1) {
return astype(array({start}), dtype, to_stream(s)); return astype(array({start}), dtype, s);
} }
array sequence = arange(0, num, float32, to_stream(s)); array t = divide(arange(0, num, float32, s), array(num - 1, float32), s);
float step = (stop - start) / (num - 1); array t_bar = subtract(array(1, float32), t, s);
return astype( return astype(
add(multiply(sequence, array(step), to_stream(s)), add(multiply(t_bar, array(start, float32), s),
array(start), multiply(t, array(stop, float32), s),
to_stream(s)), s),
dtype, dtype,
to_stream(s)); s);
} }
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {

View File

@ -2189,6 +2189,14 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array(np.linspace(1, 10, 1)) expected = mx.array(np.linspace(1, 10, 1))
self.assertEqualArray(d, expected) self.assertEqualArray(d, expected)
# Ensure that the start and stop are always the ones provided
ranges = mx.random.normal((16, 2)).tolist()
nums = (2 + mx.random.uniform(shape=(16,)) * 10).astype(mx.uint32).tolist()
for (a, b), n in zip(ranges, nums):
d = mx.linspace(a, b, n).tolist()
self.assertEqual(d[0], a)
self.assertEqual(d[-1], b)
def test_repeat(self): def test_repeat(self):
# Setup data for the tests # Setup data for the tests
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])