mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 16:58:08 +08:00
allow conversion to dlpack (#1120)
This commit is contained in:
@@ -1722,6 +1722,20 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(z.dtype, mx.int32)
|
||||
self.assertEqual(z.item(), 3)
|
||||
|
||||
def test_dlpack(self):
|
||||
x = mx.array(1, dtype=mx.int32)
|
||||
y = np.from_dlpack(x)
|
||||
self.assertTrue(mx.array_equal(y, x))
|
||||
|
||||
x = mx.array([[1.0, 2.0], [3.0, 4.0]])
|
||||
y = np.from_dlpack(x)
|
||||
self.assertTrue(mx.array_equal(y, x))
|
||||
|
||||
x = mx.arange(16).reshape(4, 4)
|
||||
x = x[::2, ::2]
|
||||
y = np.from_dlpack(x)
|
||||
self.assertTrue(mx.array_equal(y, x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user