fix conversion to array (#1070)

This commit is contained in:
Awni Hannun
2024-05-06 16:02:49 -07:00
committed by GitHub
parent 6992498e7a
commit 9814a2ae12
4 changed files with 105 additions and 54 deletions

View File

@@ -1710,6 +1710,13 @@ class TestArray(mlx_tests.MLXTestCase):
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def test_add_numpy(self):
x = mx.array(1)
y = np.array(2, dtype=np.int32)
z = x + y
self.assertEqual(z.dtype, mx.int32)
self.assertEqual(z.item(), 3)
if __name__ == "__main__":
unittest.main()