mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 16:21:14 +08:00
fix: handle linspace function when num is 1 (#602)
* fix: handle linspace function when num is 1 * add comment * fix test case * remove breakpoint
This commit is contained in:
parent
4fd2fb84a6
commit
11a9fd40f0
@ -148,6 +148,9 @@ array linspace(
|
|||||||
msg << "[linspace] number of samples, " << num << ", must be non-negative.";
|
msg << "[linspace] number of samples, " << num << ", must be non-negative.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
if (num == 1) {
|
||||||
|
return astype(array({start}), dtype, to_stream(s));
|
||||||
|
}
|
||||||
array sequence = arange(0, num, float32, to_stream(s));
|
array sequence = arange(0, num, float32, to_stream(s));
|
||||||
float step = (stop - start) / (num - 1);
|
float step = (stop - start) / (num - 1);
|
||||||
return astype(
|
return astype(
|
||||||
|
@ -1661,6 +1661,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected = mx.array(np.linspace(0, 1, 10))
|
expected = mx.array(np.linspace(0, 1, 10))
|
||||||
self.assertEqualArray(d, expected)
|
self.assertEqualArray(d, expected)
|
||||||
|
|
||||||
|
# Test num equal to 1
|
||||||
|
d = mx.linspace(1, 10, 1)
|
||||||
|
expected = mx.array(np.linspace(1, 10, 1))
|
||||||
|
self.assertEqualArray(d, expected)
|
||||||
|
|
||||||
def test_repeat(self):
|
def test_repeat(self):
|
||||||
# Setup data for the tests
|
# Setup data for the tests
|
||||||
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
||||||
|
Loading…
Reference in New Issue
Block a user