diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index b5a89e308..1e934befd 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -84,6 +84,7 @@ Operations max maximum mean + meshgrid min minimum moveaxis diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9f997ebaf..004421e96 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -696,6 +696,41 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) { return split(a, num_splits, 0, to_stream(s)); } +std::vector meshgrid( + const std::vector& arrays, + bool sparse /* = false */, + std::string indexing /* = "xy" */, + StreamOrDevice s /* = {} */) { + if (indexing != "xy" && indexing != "ij") { + throw std::invalid_argument( + "[meshgrid] Invalid indexing value. Valid values are 'xy' and 'ij'."); + } + + auto ndim = arrays.size(); + std::vector outputs; + for (int i = 0; i < ndim; ++i) { + std::vector shape(ndim, 1); + shape[i] = -1; + outputs.push_back(reshape(arrays[i], std::move(shape), s)); + } + + if (indexing == "xy" and ndim > 1) { + std::vector shape(ndim, 1); + + shape[1] = arrays[0].size(); + outputs[0] = reshape(arrays[0], shape, s); + shape[1] = 1; + shape[0] = arrays[1].size(); + outputs[1] = reshape(arrays[1], std::move(shape), s); + } + + if (!sparse) { + outputs = broadcast_arrays(outputs, s); + } + + return outputs; +} + array clip( const array& a, const std::optional& a_min, diff --git a/mlx/ops.h b/mlx/ops.h index 0118a2f35..4b18635e2 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -197,6 +197,13 @@ std::vector split( std::vector split(const array& a, const std::vector& indices, StreamOrDevice s = {}); +/** A vector of coordinate arrays from coordinate vectors. */ +std::vector meshgrid( + const std::vector& arrays, + bool sparse = false, + std::string indexing = "xy", + StreamOrDevice s = {}); + /** * Clip (limit) the values in an array. */ diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 1c08b55a5..a4ca29d6e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2568,6 +2568,35 @@ void init_ops(nb::module_& m) { Returns: array: The resulting stacked array. )pbdoc"); + m.def( + "meshgrid", + [](nb::args arrays_, + bool sparse, + std::string indexing, + StreamOrDevice s) { + std::vector arrays = nb::cast>(arrays_); + return meshgrid(arrays, sparse, indexing, s); + }, + "arrays"_a, + "sparse"_a = false, + "indexing"_a = "xy", + "stream"_a = nb::none(), + nb::sig( + "def meshgrid(*arrays: array, sparse: Optional[bool] = false, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Generate multidimensional coordinate grids from 1-D coordinate arrays + + Args: + arrays (array): Input arrays. + sparse (bool, optional): If ``True``, a sparse grid is returned in which each output + array has a single non-zero element. If ``False``, a dense grid is returned. + Defaults to ``False``. + indexing (str, optional): Cartesian ('xy') or matrix ('ij') indexing of the output arrays. + Defaults to ``'xy'``. + + Returns: + list(array): The output arrays. + )pbdoc"); m.def( "repeat", [](const array& array, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 4906b5b55..e170c59fc 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1467,6 +1467,69 @@ class TestOps(mlx_tests.MLXTestCase): b = mx.array([1, 2]) mx.concatenate([a, b], axis=0) + def test_meshgrid(self): + x = mx.array([1, 2, 3], dtype=mx.int32) + y = np.array([1, 2, 3], dtype=np.int32) + + # Test single input + a_mlx = mx.meshgrid(x) + a_np = np.meshgrid(y) + self.assertEqualArray(a_mlx[0], mx.array(a_np[0])) + + # Test sparse + a_mlx, b_mlx, c_mlx = mx.meshgrid(x, x, x, sparse=True) + a_np, b_np, c_np = np.meshgrid(y, y, y, sparse=True) + self.assertEqualArray(a_mlx, mx.array(a_np)) + self.assertEqualArray(b_mlx, mx.array(b_np)) + self.assertEqualArray(c_mlx, mx.array(c_np)) + + # Test different lengths + x = mx.array([1, 2], dtype=mx.int32) + y = mx.array([1, 2, 3], dtype=mx.int32) + z = np.array([1, 2], dtype=np.int32) + w = np.array([1, 2, 3], dtype=np.int32) + a_mlx, b_mlx = mx.meshgrid(x, y) + a_np, b_np = np.meshgrid(z, w) + self.assertEqualArray(a_mlx, mx.array(a_np)) + self.assertEqualArray(b_mlx, mx.array(b_np)) + + # Test empty input + x = mx.array([], dtype=mx.int32) + y = np.array([], dtype=np.int32) + a_mlx = mx.meshgrid(x) + a_np = np.meshgrid(y) + self.assertEqualArray(a_mlx[0], mx.array(a_np[0])) + + # Test float32 input + x = mx.array([1.1, 2.2, 3.3], dtype=mx.float32) + y = np.array([1.1, 2.2, 3.3], dtype=np.float32) + a_mlx = mx.meshgrid(x, x, x) + a_np = np.meshgrid(y, y, y) + self.assertEqualArray(a_mlx[0], mx.array(a_np[0])) + self.assertEqualArray(a_mlx[1], mx.array(a_np[1])) + self.assertEqualArray(a_mlx[2], mx.array(a_np[2])) + + # Test ij indexing + x = mx.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=mx.float32) + y = np.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=np.float32) + a_mlx = mx.meshgrid(x, x, indexing="ij") + a_np = np.meshgrid(y, y, indexing="ij") + self.assertEqualArray(a_mlx[0], mx.array(a_np[0])) + self.assertEqualArray(a_mlx[1], mx.array(a_np[1])) + + # Test different lengths, sparse, and ij indexing + a = mx.array([1, 2], dtype=mx.int64) + b = mx.array([1, 2, 3], dtype=mx.int64) + c = mx.array([1, 2, 3, 4], dtype=mx.int64) + x = np.array([1, 2], dtype=np.int64) + y = np.array([1, 2, 3], dtype=np.int64) + z = np.array([1, 2, 3, 4], dtype=np.int64) + a_mlx, b_mlx, c_mlx = mx.meshgrid(a, b, c, sparse=True, indexing="ij") + a_np, b_np, c_np = np.meshgrid(x, y, z, sparse=True, indexing="ij") + self.assertEqualArray(a_mlx, mx.array(a_np)) + self.assertEqualArray(b_mlx, mx.array(b_np)) + self.assertEqualArray(c_mlx, mx.array(c_np)) + def test_pad(self): pad_width_and_values = [ ([(1, 1), (1, 1), (1, 1)], 0), @@ -1758,7 +1821,7 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.array(np.linspace(0, 1)) self.assertEqualArray(a, expected) - # Test int32 dtype + # Test int64 dtype b = mx.linspace(0, 10, 5, mx.int64) expected = mx.array(np.linspace(0, 10, 5, dtype=int)) self.assertEqualArray(b, expected) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 545526c2e..ead089cec 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3156,3 +3156,28 @@ TEST_CASE("test topk") { CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item()); } } + +TEST_CASE("test meshgrid") { + // Test default + auto x = array({1, 2, 3}, {3}); + auto in = std::vector{x}; + auto out = meshgrid(in); + CHECK(array_equal(out[0], x).item()); + + // Test different lengths + auto y = array({4, 5}, {2}); + in = std::vector{x, y}; + out = meshgrid(in); + auto expected_zero = array({1, 2, 3, 1, 2, 3}, {2, 3}); + auto expected_one = array({4, 4, 4, 5, 5, 5}, {2, 3}); + CHECK(array_equal(out[0], expected_zero).item()); + CHECK(array_equal(out[1], expected_one).item()); + + // Test sparse true + in = std::vector{x, x}; + out = meshgrid(in, true); + expected_zero = array({1, 2, 3}, {1, 3}); + expected_one = array({1, 2, 3}, {3, 1}); + CHECK(array_equal(out[0], expected_zero).item()); + CHECK(array_equal(out[1], expected_one).item()); +} \ No newline at end of file