allow conversion to dlpack (#1120)

This commit is contained in:
Awni Hannun
2024-05-16 16:11:37 -07:00
committed by GitHub
parent 8b76571896
commit 81dd33af66
4 changed files with 41 additions and 26 deletions

View File

@@ -1722,6 +1722,20 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(z.dtype, mx.int32)
self.assertEqual(z.item(), 3)
def test_dlpack(self):
x = mx.array(1, dtype=mx.int32)
y = np.from_dlpack(x)
self.assertTrue(mx.array_equal(y, x))
x = mx.array([[1.0, 2.0], [3.0, 4.0]])
y = np.from_dlpack(x)
self.assertTrue(mx.array_equal(y, x))
x = mx.arange(16).reshape(4, 4)
x = x[::2, ::2]
y = np.from_dlpack(x)
self.assertTrue(mx.array_equal(y, x))
if __name__ == "__main__":
unittest.main()