Ensure linspace always contains start and stop (#1883)

This commit is contained in:
Angelos Katharopoulos
2025-02-19 13:53:20 -08:00
committed by GitHub
parent 344a29506e
commit 1a2cb72030
2 changed files with 15 additions and 7 deletions

View File

@@ -2189,6 +2189,14 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array(np.linspace(1, 10, 1))
self.assertEqualArray(d, expected)
# Ensure that the start and stop are always the ones provided
ranges = mx.random.normal((16, 2)).tolist()
nums = (2 + mx.random.uniform(shape=(16,)) * 10).astype(mx.uint32).tolist()
for (a, b), n in zip(ranges, nums):
d = mx.linspace(a, b, n).tolist()
self.assertEqual(d[0], a)
self.assertEqual(d[-1], b)
def test_repeat(self):
# Setup data for the tests
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])