From 1a2cb72030d73bd1fcc0cddc550c8c5491f39867 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 19 Feb 2025 13:53:20 -0800 Subject: [PATCH] Ensure linspace always contains start and stop (#1883) --- mlx/ops.cpp | 14 +++++++------- python/tests/test_ops.py | 8 ++++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 514a244b0..4e147487d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -230,16 +230,16 @@ array linspace( throw std::invalid_argument(msg.str()); } if (num == 1) { - return astype(array({start}), dtype, to_stream(s)); + return astype(array({start}), dtype, s); } - array sequence = arange(0, num, float32, to_stream(s)); - float step = (stop - start) / (num - 1); + array t = divide(arange(0, num, float32, s), array(num - 1, float32), s); + array t_bar = subtract(array(1, float32), t, s); return astype( - add(multiply(sequence, array(step), to_stream(s)), - array(start), - to_stream(s)), + add(multiply(t_bar, array(start, float32), s), + multiply(t, array(stop, float32), s), + s), dtype, - to_stream(s)); + s); } array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 9c349bb1c..43f1b3335 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2189,6 +2189,14 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.array(np.linspace(1, 10, 1)) self.assertEqualArray(d, expected) + # Ensure that the start and stop are always the ones provided + ranges = mx.random.normal((16, 2)).tolist() + nums = (2 + mx.random.uniform(shape=(16,)) * 10).astype(mx.uint32).tolist() + for (a, b), n in zip(ranges, nums): + d = mx.linspace(a, b, n).tolist() + self.assertEqual(d[0], a) + self.assertEqual(d[-1], b) + def test_repeat(self): # Setup data for the tests data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])