Add mx.meshgrid (#961)

This commit is contained in:
Abe Leininger
2024-04-09 14:43:08 -04:00
committed by GitHub
parent ae812350f9
commit a1a31eed27
6 changed files with 161 additions and 1 deletions

View File

@@ -1467,6 +1467,69 @@ class TestOps(mlx_tests.MLXTestCase):
b = mx.array([1, 2])
mx.concatenate([a, b], axis=0)
def test_meshgrid(self):
x = mx.array([1, 2, 3], dtype=mx.int32)
y = np.array([1, 2, 3], dtype=np.int32)
# Test single input
a_mlx = mx.meshgrid(x)
a_np = np.meshgrid(y)
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
# Test sparse
a_mlx, b_mlx, c_mlx = mx.meshgrid(x, x, x, sparse=True)
a_np, b_np, c_np = np.meshgrid(y, y, y, sparse=True)
self.assertEqualArray(a_mlx, mx.array(a_np))
self.assertEqualArray(b_mlx, mx.array(b_np))
self.assertEqualArray(c_mlx, mx.array(c_np))
# Test different lengths
x = mx.array([1, 2], dtype=mx.int32)
y = mx.array([1, 2, 3], dtype=mx.int32)
z = np.array([1, 2], dtype=np.int32)
w = np.array([1, 2, 3], dtype=np.int32)
a_mlx, b_mlx = mx.meshgrid(x, y)
a_np, b_np = np.meshgrid(z, w)
self.assertEqualArray(a_mlx, mx.array(a_np))
self.assertEqualArray(b_mlx, mx.array(b_np))
# Test empty input
x = mx.array([], dtype=mx.int32)
y = np.array([], dtype=np.int32)
a_mlx = mx.meshgrid(x)
a_np = np.meshgrid(y)
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
# Test float32 input
x = mx.array([1.1, 2.2, 3.3], dtype=mx.float32)
y = np.array([1.1, 2.2, 3.3], dtype=np.float32)
a_mlx = mx.meshgrid(x, x, x)
a_np = np.meshgrid(y, y, y)
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
self.assertEqualArray(a_mlx[1], mx.array(a_np[1]))
self.assertEqualArray(a_mlx[2], mx.array(a_np[2]))
# Test ij indexing
x = mx.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=mx.float32)
y = np.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=np.float32)
a_mlx = mx.meshgrid(x, x, indexing="ij")
a_np = np.meshgrid(y, y, indexing="ij")
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
self.assertEqualArray(a_mlx[1], mx.array(a_np[1]))
# Test different lengths, sparse, and ij indexing
a = mx.array([1, 2], dtype=mx.int64)
b = mx.array([1, 2, 3], dtype=mx.int64)
c = mx.array([1, 2, 3, 4], dtype=mx.int64)
x = np.array([1, 2], dtype=np.int64)
y = np.array([1, 2, 3], dtype=np.int64)
z = np.array([1, 2, 3, 4], dtype=np.int64)
a_mlx, b_mlx, c_mlx = mx.meshgrid(a, b, c, sparse=True, indexing="ij")
a_np, b_np, c_np = np.meshgrid(x, y, z, sparse=True, indexing="ij")
self.assertEqualArray(a_mlx, mx.array(a_np))
self.assertEqualArray(b_mlx, mx.array(b_np))
self.assertEqualArray(c_mlx, mx.array(c_np))
def test_pad(self):
pad_width_and_values = [
([(1, 1), (1, 1), (1, 1)], 0),
@@ -1758,7 +1821,7 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array(np.linspace(0, 1))
self.assertEqualArray(a, expected)
# Test int32 dtype
# Test int64 dtype
b = mx.linspace(0, 10, 5, mx.int64)
expected = mx.array(np.linspace(0, 10, 5, dtype=int))
self.assertEqualArray(b, expected)