mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add mx.meshgrid (#961)
This commit is contained in:
parent
ae812350f9
commit
a1a31eed27
@ -84,6 +84,7 @@ Operations
|
|||||||
max
|
max
|
||||||
maximum
|
maximum
|
||||||
mean
|
mean
|
||||||
|
meshgrid
|
||||||
min
|
min
|
||||||
minimum
|
minimum
|
||||||
moveaxis
|
moveaxis
|
||||||
|
35
mlx/ops.cpp
35
mlx/ops.cpp
@ -696,6 +696,41 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
|
|||||||
return split(a, num_splits, 0, to_stream(s));
|
return split(a, num_splits, 0, to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> meshgrid(
|
||||||
|
const std::vector<array>& arrays,
|
||||||
|
bool sparse /* = false */,
|
||||||
|
std::string indexing /* = "xy" */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (indexing != "xy" && indexing != "ij") {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[meshgrid] Invalid indexing value. Valid values are 'xy' and 'ij'.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ndim = arrays.size();
|
||||||
|
std::vector<array> outputs;
|
||||||
|
for (int i = 0; i < ndim; ++i) {
|
||||||
|
std::vector<int> shape(ndim, 1);
|
||||||
|
shape[i] = -1;
|
||||||
|
outputs.push_back(reshape(arrays[i], std::move(shape), s));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indexing == "xy" and ndim > 1) {
|
||||||
|
std::vector<int> shape(ndim, 1);
|
||||||
|
|
||||||
|
shape[1] = arrays[0].size();
|
||||||
|
outputs[0] = reshape(arrays[0], shape, s);
|
||||||
|
shape[1] = 1;
|
||||||
|
shape[0] = arrays[1].size();
|
||||||
|
outputs[1] = reshape(arrays[1], std::move(shape), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sparse) {
|
||||||
|
outputs = broadcast_arrays(outputs, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
array clip(
|
array clip(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::optional<array>& a_min,
|
const std::optional<array>& a_min,
|
||||||
|
@ -197,6 +197,13 @@ std::vector<array> split(
|
|||||||
std::vector<array>
|
std::vector<array>
|
||||||
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
|
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** A vector of coordinate arrays from coordinate vectors. */
|
||||||
|
std::vector<array> meshgrid(
|
||||||
|
const std::vector<array>& arrays,
|
||||||
|
bool sparse = false,
|
||||||
|
std::string indexing = "xy",
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clip (limit) the values in an array.
|
* Clip (limit) the values in an array.
|
||||||
*/
|
*/
|
||||||
|
@ -2568,6 +2568,35 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The resulting stacked array.
|
array: The resulting stacked array.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"meshgrid",
|
||||||
|
[](nb::args arrays_,
|
||||||
|
bool sparse,
|
||||||
|
std::string indexing,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
std::vector<array> arrays = nb::cast<std::vector<array>>(arrays_);
|
||||||
|
return meshgrid(arrays, sparse, indexing, s);
|
||||||
|
},
|
||||||
|
"arrays"_a,
|
||||||
|
"sparse"_a = false,
|
||||||
|
"indexing"_a = "xy",
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def meshgrid(*arrays: array, sparse: Optional[bool] = false, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Generate multidimensional coordinate grids from 1-D coordinate arrays
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arrays (array): Input arrays.
|
||||||
|
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
|
||||||
|
array has a single non-zero element. If ``False``, a dense grid is returned.
|
||||||
|
Defaults to ``False``.
|
||||||
|
indexing (str, optional): Cartesian ('xy') or matrix ('ij') indexing of the output arrays.
|
||||||
|
Defaults to ``'xy'``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list(array): The output arrays.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"repeat",
|
"repeat",
|
||||||
[](const array& array,
|
[](const array& array,
|
||||||
|
@ -1467,6 +1467,69 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
b = mx.array([1, 2])
|
b = mx.array([1, 2])
|
||||||
mx.concatenate([a, b], axis=0)
|
mx.concatenate([a, b], axis=0)
|
||||||
|
|
||||||
|
def test_meshgrid(self):
|
||||||
|
x = mx.array([1, 2, 3], dtype=mx.int32)
|
||||||
|
y = np.array([1, 2, 3], dtype=np.int32)
|
||||||
|
|
||||||
|
# Test single input
|
||||||
|
a_mlx = mx.meshgrid(x)
|
||||||
|
a_np = np.meshgrid(y)
|
||||||
|
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
|
||||||
|
|
||||||
|
# Test sparse
|
||||||
|
a_mlx, b_mlx, c_mlx = mx.meshgrid(x, x, x, sparse=True)
|
||||||
|
a_np, b_np, c_np = np.meshgrid(y, y, y, sparse=True)
|
||||||
|
self.assertEqualArray(a_mlx, mx.array(a_np))
|
||||||
|
self.assertEqualArray(b_mlx, mx.array(b_np))
|
||||||
|
self.assertEqualArray(c_mlx, mx.array(c_np))
|
||||||
|
|
||||||
|
# Test different lengths
|
||||||
|
x = mx.array([1, 2], dtype=mx.int32)
|
||||||
|
y = mx.array([1, 2, 3], dtype=mx.int32)
|
||||||
|
z = np.array([1, 2], dtype=np.int32)
|
||||||
|
w = np.array([1, 2, 3], dtype=np.int32)
|
||||||
|
a_mlx, b_mlx = mx.meshgrid(x, y)
|
||||||
|
a_np, b_np = np.meshgrid(z, w)
|
||||||
|
self.assertEqualArray(a_mlx, mx.array(a_np))
|
||||||
|
self.assertEqualArray(b_mlx, mx.array(b_np))
|
||||||
|
|
||||||
|
# Test empty input
|
||||||
|
x = mx.array([], dtype=mx.int32)
|
||||||
|
y = np.array([], dtype=np.int32)
|
||||||
|
a_mlx = mx.meshgrid(x)
|
||||||
|
a_np = np.meshgrid(y)
|
||||||
|
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
|
||||||
|
|
||||||
|
# Test float32 input
|
||||||
|
x = mx.array([1.1, 2.2, 3.3], dtype=mx.float32)
|
||||||
|
y = np.array([1.1, 2.2, 3.3], dtype=np.float32)
|
||||||
|
a_mlx = mx.meshgrid(x, x, x)
|
||||||
|
a_np = np.meshgrid(y, y, y)
|
||||||
|
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
|
||||||
|
self.assertEqualArray(a_mlx[1], mx.array(a_np[1]))
|
||||||
|
self.assertEqualArray(a_mlx[2], mx.array(a_np[2]))
|
||||||
|
|
||||||
|
# Test ij indexing
|
||||||
|
x = mx.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=mx.float32)
|
||||||
|
y = np.array([1.1, 2.2, 3.3, 4.4, 5.5], dtype=np.float32)
|
||||||
|
a_mlx = mx.meshgrid(x, x, indexing="ij")
|
||||||
|
a_np = np.meshgrid(y, y, indexing="ij")
|
||||||
|
self.assertEqualArray(a_mlx[0], mx.array(a_np[0]))
|
||||||
|
self.assertEqualArray(a_mlx[1], mx.array(a_np[1]))
|
||||||
|
|
||||||
|
# Test different lengths, sparse, and ij indexing
|
||||||
|
a = mx.array([1, 2], dtype=mx.int64)
|
||||||
|
b = mx.array([1, 2, 3], dtype=mx.int64)
|
||||||
|
c = mx.array([1, 2, 3, 4], dtype=mx.int64)
|
||||||
|
x = np.array([1, 2], dtype=np.int64)
|
||||||
|
y = np.array([1, 2, 3], dtype=np.int64)
|
||||||
|
z = np.array([1, 2, 3, 4], dtype=np.int64)
|
||||||
|
a_mlx, b_mlx, c_mlx = mx.meshgrid(a, b, c, sparse=True, indexing="ij")
|
||||||
|
a_np, b_np, c_np = np.meshgrid(x, y, z, sparse=True, indexing="ij")
|
||||||
|
self.assertEqualArray(a_mlx, mx.array(a_np))
|
||||||
|
self.assertEqualArray(b_mlx, mx.array(b_np))
|
||||||
|
self.assertEqualArray(c_mlx, mx.array(c_np))
|
||||||
|
|
||||||
def test_pad(self):
|
def test_pad(self):
|
||||||
pad_width_and_values = [
|
pad_width_and_values = [
|
||||||
([(1, 1), (1, 1), (1, 1)], 0),
|
([(1, 1), (1, 1), (1, 1)], 0),
|
||||||
@ -1758,7 +1821,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected = mx.array(np.linspace(0, 1))
|
expected = mx.array(np.linspace(0, 1))
|
||||||
self.assertEqualArray(a, expected)
|
self.assertEqualArray(a, expected)
|
||||||
|
|
||||||
# Test int32 dtype
|
# Test int64 dtype
|
||||||
b = mx.linspace(0, 10, 5, mx.int64)
|
b = mx.linspace(0, 10, 5, mx.int64)
|
||||||
expected = mx.array(np.linspace(0, 10, 5, dtype=int))
|
expected = mx.array(np.linspace(0, 10, 5, dtype=int))
|
||||||
self.assertEqualArray(b, expected)
|
self.assertEqualArray(b, expected)
|
||||||
|
@ -3156,3 +3156,28 @@ TEST_CASE("test topk") {
|
|||||||
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
|
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test meshgrid") {
|
||||||
|
// Test default
|
||||||
|
auto x = array({1, 2, 3}, {3});
|
||||||
|
auto in = std::vector<array>{x};
|
||||||
|
auto out = meshgrid(in);
|
||||||
|
CHECK(array_equal(out[0], x).item<bool>());
|
||||||
|
|
||||||
|
// Test different lengths
|
||||||
|
auto y = array({4, 5}, {2});
|
||||||
|
in = std::vector<array>{x, y};
|
||||||
|
out = meshgrid(in);
|
||||||
|
auto expected_zero = array({1, 2, 3, 1, 2, 3}, {2, 3});
|
||||||
|
auto expected_one = array({4, 4, 4, 5, 5, 5}, {2, 3});
|
||||||
|
CHECK(array_equal(out[0], expected_zero).item<bool>());
|
||||||
|
CHECK(array_equal(out[1], expected_one).item<bool>());
|
||||||
|
|
||||||
|
// Test sparse true
|
||||||
|
in = std::vector<array>{x, x};
|
||||||
|
out = meshgrid(in, true);
|
||||||
|
expected_zero = array({1, 2, 3}, {1, 3});
|
||||||
|
expected_one = array({1, 2, 3}, {3, 1});
|
||||||
|
CHECK(array_equal(out[0], expected_zero).item<bool>());
|
||||||
|
CHECK(array_equal(out[1], expected_one).item<bool>());
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user