mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-11 15:06:42 +08:00
Add mx.meshgrid (#961)
This commit is contained in:
@@ -2568,6 +2568,35 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The resulting stacked array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"meshgrid",
|
||||
[](nb::args arrays_,
|
||||
bool sparse,
|
||||
std::string indexing,
|
||||
StreamOrDevice s) {
|
||||
std::vector<array> arrays = nb::cast<std::vector<array>>(arrays_);
|
||||
return meshgrid(arrays, sparse, indexing, s);
|
||||
},
|
||||
"arrays"_a,
|
||||
"sparse"_a = false,
|
||||
"indexing"_a = "xy",
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def meshgrid(*arrays: array, sparse: Optional[bool] = false, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate multidimensional coordinate grids from 1-D coordinate arrays
|
||||
|
||||
Args:
|
||||
arrays (array): Input arrays.
|
||||
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
|
||||
array has a single non-zero element. If ``False``, a dense grid is returned.
|
||||
Defaults to ``False``.
|
||||
indexing (str, optional): Cartesian ('xy') or matrix ('ij') indexing of the output arrays.
|
||||
Defaults to ``'xy'``.
|
||||
|
||||
Returns:
|
||||
list(array): The output arrays.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"repeat",
|
||||
[](const array& array,
|
||||
|
||||
@@ -1467,6 +1467,69 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
b = mx.array([1, 2])
|
||||
mx.concatenate([a, b], axis=0)
|
||||
|
||||
def test_meshgrid(self):
|
||||
x = mx.array([1, 2, 3], dtype=mx.int32)
|
||||
y = np.array([1, 2, 3], dtype=np.int32)
|
||||
|
||||
# Test single input
|
||||
a_mlx = mx.meshgrid(x)
|
||||
a_np = np.meshgrid(y)
|
||||
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
|
||||
|
||||
# Test sparse
|
||||
a_mlx, b_mlx, c_mlx = mx.meshgrid(x, x, x, sparse=True)
|
||||
a_np, b_np, c_np = np.meshgrid(y, y, y, sparse=True)
|
||||
self.assertEqualArray(a_mlx, mx.array(a_np))
|
||||
self.assertEqualArray(b_mlx, mx.array(b_np))
|
||||
self.assertEqualArray(c_mlx, mx.array(c_np))
|
||||
|
||||
# Test different lengths
|
||||
x = mx.array([1, 2], dtype=mx.int32)
|
||||
y = mx.array([1, 2, 3], dtype=mx.int32)
|
||||
z = np.array([1, 2], dtype=np.int32)
|
||||
w = np.array([1, 2, 3], dtype=np.int32)
|
||||
a_mlx, b_mlx = mx.meshgrid(x, y)
|
||||
a_np, b_np = np.meshgrid(z, w)
|
||||
self.assertEqualArray(a_mlx, mx.array(a_np))
|
||||
self.assertEqualArray(b_mlx, mx.array(b_np))
|
||||
|
||||
# Test empty input
|
||||
x = mx.array([], dtype=mx.int32)
|
||||
y = np.array([], dtype=np.int32)
|
||||
a_mlx = mx.meshgrid(x)
|
||||
a_np = np.meshgrid(y)
|
||||
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
|
||||
|
||||
# Test float32 input
|
||||
x = mx.array([1.1, 2.2, 3.3], dtype=mx.float32)
|
||||
y = np.array([1.1, 2.2, 3.3], dtype=np.float32)
|
||||
a_mlx = mx.meshgrid(x, x, x)
|
||||
a_np = np.meshgrid(y, y, y)
|
||||
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
|
||||
self.assertEqualArray(a_mlx[1], mx.array(a_np[1]))
|
||||
self.assertEqualArray(a_mlx[2], mx.array(a_np[2]))
|
||||
|
||||
# Test ij indexing
|
||||
x = mx.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=mx.float32)
|
||||
y = np.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=np.float32)
|
||||
a_mlx = mx.meshgrid(x, x, indexing="ij")
|
||||
a_np = np.meshgrid(y, y, indexing="ij")
|
||||
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
|
||||
self.assertEqualArray(a_mlx[1], mx.array(a_np[1]))
|
||||
|
||||
# Test different lengths, sparse, and ij indexing
|
||||
a = mx.array([1, 2], dtype=mx.int64)
|
||||
b = mx.array([1, 2, 3], dtype=mx.int64)
|
||||
c = mx.array([1, 2, 3, 4], dtype=mx.int64)
|
||||
x = np.array([1, 2], dtype=np.int64)
|
||||
y = np.array([1, 2, 3], dtype=np.int64)
|
||||
z = np.array([1, 2, 3, 4], dtype=np.int64)
|
||||
a_mlx, b_mlx, c_mlx = mx.meshgrid(a, b, c, sparse=True, indexing="ij")
|
||||
a_np, b_np, c_np = np.meshgrid(x, y, z, sparse=True, indexing="ij")
|
||||
self.assertEqualArray(a_mlx, mx.array(a_np))
|
||||
self.assertEqualArray(b_mlx, mx.array(b_np))
|
||||
self.assertEqualArray(c_mlx, mx.array(c_np))
|
||||
|
||||
def test_pad(self):
|
||||
pad_width_and_values = [
|
||||
([(1, 1), (1, 1), (1, 1)], 0),
|
||||
@@ -1758,7 +1821,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = mx.array(np.linspace(0, 1))
|
||||
self.assertEqualArray(a, expected)
|
||||
|
||||
# Test int32 dtype
|
||||
# Test int64 dtype
|
||||
b = mx.linspace(0, 10, 5, mx.int64)
|
||||
expected = mx.array(np.linspace(0, 10, 5, dtype=int))
|
||||
self.assertEqualArray(b, expected)
|
||||
|
||||
Reference in New Issue
Block a user