From 11a9fd40f034a948b89360159f871a927d1fc034 Mon Sep 17 00:00:00 2001 From: Avikant Srivastava Date: Mon, 5 Feb 2024 00:33:49 +0530 Subject: [PATCH] fix: handle linspace function when num is 1 (#602) * fix: handle linspace function when num is 1 * add comment * fix test case * remove breakpoint --- mlx/ops.cpp | 3 +++ python/tests/test_ops.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index cb15a9570..3f784dd9d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -148,6 +148,9 @@ array linspace( msg << "[linspace] number of samples, " << num << ", must be non-negative."; throw std::invalid_argument(msg.str()); } + if (num == 1) { + return astype(array({start}), dtype, to_stream(s)); + } array sequence = arange(0, num, float32, to_stream(s)); float step = (stop - start) / (num - 1); return astype( diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 2d5366a6d..add06c729 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1661,6 +1661,11 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.array(np.linspace(0, 1, 10)) self.assertEqualArray(d, expected) + # Test num equal to 1 + d = mx.linspace(1, 10, 1) + expected = mx.array(np.linspace(1, 10, 1)) + self.assertEqualArray(d, expected) + def test_repeat(self): # Setup data for the tests data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])