mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Added linspace (#181)
* linspace ops support --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1491,6 +1491,27 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
def test_linspace(self):
|
||||
# Test default num = 50
|
||||
a = mx.linspace(0, 1)
|
||||
expected = mx.array(np.linspace(0, 1))
|
||||
self.assertEqualArray(a, expected)
|
||||
|
||||
# Test int32 dtype
|
||||
b = mx.linspace(0, 10, 5, mx.int64)
|
||||
expected = mx.array(np.linspace(0, 10, 5, dtype=int))
|
||||
self.assertEqualArray(b, expected)
|
||||
|
||||
# Test negative sequence with float start and stop
|
||||
c = mx.linspace(-2.7, -0.7, 7)
|
||||
expected = mx.array(np.linspace(-2.7, -0.7, 7))
|
||||
self.assertEqualArray(c, expected)
|
||||
|
||||
# Test irrational step size of 1/9
|
||||
d = mx.linspace(0, 1, 10)
|
||||
expected = mx.array(np.linspace(0, 1, 10))
|
||||
self.assertEqualArray(d, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user