fix: handle linspace function when num is 1 (#602)

* fix: handle linspace function when num is 1

* add comment

* fix test case

* remove breakpoint
This commit is contained in:
Avikant Srivastava 2024-02-05 00:33:49 +05:30 committed by GitHub
parent 4fd2fb84a6
commit 11a9fd40f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 0 deletions

View File

@ -148,6 +148,9 @@ array linspace(
msg << "[linspace] number of samples, " << num << ", must be non-negative.";
throw std::invalid_argument(msg.str());
}
if (num == 1) {
return astype(array({start}), dtype, to_stream(s));
}
array sequence = arange(0, num, float32, to_stream(s));
float step = (stop - start) / (num - 1);
return astype(

View File

@ -1661,6 +1661,11 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array(np.linspace(0, 1, 10))
self.assertEqualArray(d, expected)
# Test num equal to 1
d = mx.linspace(1, 10, 1)
expected = mx.array(np.linspace(1, 10, 1))
self.assertEqualArray(d, expected)
def test_repeat(self):
# Setup data for the tests
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])