mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
This commit is contained in:
parent
e5851e52b1
commit
e28b57e371
@ -85,6 +85,7 @@ Operations
|
|||||||
sqrt
|
sqrt
|
||||||
square
|
square
|
||||||
squeeze
|
squeeze
|
||||||
|
stack
|
||||||
stop_gradient
|
stop_gradient
|
||||||
subtract
|
subtract
|
||||||
sum
|
sum
|
||||||
|
31
mlx/ops.cpp
31
mlx/ops.cpp
@ -574,11 +574,11 @@ array concatenate(
|
|||||||
shape[ax] += a.shape(ax);
|
shape[ax] += a.shape(ax);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Promote all the arrays to the same type
|
||||||
|
auto dtype = result_type(arrays);
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
shape,
|
shape, dtype, std::make_unique<Concatenate>(to_stream(s), ax), arrays);
|
||||||
arrays[0].dtype(),
|
|
||||||
std::make_unique<Concatenate>(to_stream(s), ax),
|
|
||||||
arrays);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array concatenate(
|
array concatenate(
|
||||||
@ -591,6 +591,29 @@ array concatenate(
|
|||||||
return concatenate(flat_inputs, 0, s);
|
return concatenate(flat_inputs, 0, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Stack arrays along a new axis */
|
||||||
|
array stack(
|
||||||
|
const std::vector<array>& arrays,
|
||||||
|
int axis,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (arrays.empty()) {
|
||||||
|
throw std::invalid_argument("No arrays provided for stacking");
|
||||||
|
}
|
||||||
|
if (!is_same_shape(arrays)) {
|
||||||
|
throw std::invalid_argument("All arrays must have the same shape");
|
||||||
|
}
|
||||||
|
int normalized_axis = normalize_axis(axis, arrays[0].ndim() + 1);
|
||||||
|
std::vector<array> new_arrays;
|
||||||
|
new_arrays.reserve(arrays.size());
|
||||||
|
for (auto& a : arrays) {
|
||||||
|
new_arrays.emplace_back(expand_dims(a, normalized_axis, s));
|
||||||
|
}
|
||||||
|
return concatenate(new_arrays, axis, s);
|
||||||
|
}
|
||||||
|
array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
|
||||||
|
return stack(arrays, 0, s);
|
||||||
|
}
|
||||||
|
|
||||||
/** Pad an array with a constant value */
|
/** Pad an array with a constant value */
|
||||||
array pad(
|
array pad(
|
||||||
const array& a,
|
const array& a,
|
||||||
|
@ -174,6 +174,10 @@ array concatenate(
|
|||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
|
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Stack arrays along a new axis. */
|
||||||
|
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
|
||||||
|
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Permutes the dimensions according to the given axes. */
|
/** Permutes the dimensions according to the given axes. */
|
||||||
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
||||||
inline array transpose(
|
inline array transpose(
|
||||||
|
@ -49,6 +49,31 @@ std::vector<int> broadcast_shapes(
|
|||||||
return out_shape;
|
return out_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_same_shape(const std::vector<array>& arrays) {
|
||||||
|
if (arrays.empty())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
|
||||||
|
return (a.shape() == arrays[0].shape());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
int normalize_axis(int axis, int ndim) {
|
||||||
|
if (ndim <= 0) {
|
||||||
|
throw std::invalid_argument("Number of dimensions must be positive.");
|
||||||
|
}
|
||||||
|
if (axis < -ndim || axis >= ndim) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Axis " << axis << " is out of bounds for array with " << ndim
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (axis < 0) {
|
||||||
|
axis += ndim;
|
||||||
|
}
|
||||||
|
return axis;
|
||||||
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||||
os << "Device(";
|
os << "Device(";
|
||||||
switch (d.type) {
|
switch (d.type) {
|
||||||
|
@ -16,6 +16,15 @@ std::vector<int> broadcast_shapes(
|
|||||||
const std::vector<int>& s1,
|
const std::vector<int>& s1,
|
||||||
const std::vector<int>& s2);
|
const std::vector<int>& s2);
|
||||||
|
|
||||||
|
bool is_same_shape(const std::vector<array>& arrays);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the axis normalized to be in the range [0, ndim).
|
||||||
|
* Based on numpy's normalize_axis_index. See
|
||||||
|
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
||||||
|
*/
|
||||||
|
int normalize_axis(int axis, int ndim);
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Device& d);
|
std::ostream& operator<<(std::ostream& os, const Device& d);
|
||||||
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||||
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
||||||
|
@ -2230,6 +2230,36 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The concatenated array.
|
array: The concatenated array.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"stack",
|
||||||
|
[](const std::vector<array>& arrays,
|
||||||
|
std::optional<int> axis,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (axis.has_value()) {
|
||||||
|
return stack(arrays, axis.value(), s);
|
||||||
|
} else {
|
||||||
|
return stack(arrays, s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"arrays"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"axis"_a = 0,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Stacks the arrays along a new axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arrays (list(array)): A list of arrays to stack.
|
||||||
|
axis (int, optional): The axis in the result array along which the
|
||||||
|
input arrays are stacked. Defaults to ``0``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The resulting stacked array.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"pad",
|
"pad",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
|
@ -1371,6 +1371,37 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
np_eye_matrix = np.eye(5, 6, k=-2)
|
np_eye_matrix = np.eye(5, 6, k=-2)
|
||||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||||
|
|
||||||
|
def test_stack(self):
|
||||||
|
a = mx.ones((2,))
|
||||||
|
np_a = np.ones((2,))
|
||||||
|
b = mx.ones((2,))
|
||||||
|
np_b = np.ones((2,))
|
||||||
|
|
||||||
|
# One dimensional stack axis=0
|
||||||
|
c = mx.stack([a, b])
|
||||||
|
np_c = np.stack([np_a, np_b])
|
||||||
|
self.assertTrue(np.array_equal(c, np_c))
|
||||||
|
|
||||||
|
# One dimensional stack axis=1
|
||||||
|
c = mx.stack([a, b], axis=1)
|
||||||
|
np_c = np.stack([np_a, np_b], axis=1)
|
||||||
|
self.assertTrue(np.array_equal(c, np_c))
|
||||||
|
|
||||||
|
a = mx.ones((1, 2))
|
||||||
|
np_a = np.ones((1, 2))
|
||||||
|
b = mx.ones((1, 2))
|
||||||
|
np_b = np.ones((1, 2))
|
||||||
|
|
||||||
|
# Two dimensional stack axis=0
|
||||||
|
c = mx.stack([a, b])
|
||||||
|
np_c = np.stack([np_a, np_b])
|
||||||
|
self.assertTrue(np.array_equal(c, np_c))
|
||||||
|
|
||||||
|
# Two dimensional stack axis=1
|
||||||
|
c = mx.stack([a, b], axis=1)
|
||||||
|
np_c = np.stack([np_a, np_b], axis=1)
|
||||||
|
self.assertTrue(np.array_equal(c, np_c))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1989,6 +1989,35 @@ TEST_CASE("test where") {
|
|||||||
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
|
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test stack") {
|
||||||
|
auto x = array({});
|
||||||
|
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 0});
|
||||||
|
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{0, 1});
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3});
|
||||||
|
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 3});
|
||||||
|
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{3, 1});
|
||||||
|
|
||||||
|
auto y = array({4, 5, 6}, {3});
|
||||||
|
auto z = std::vector<array>{x, y};
|
||||||
|
CHECK_EQ(stack(z).shape(), std::vector<int>{2, 3});
|
||||||
|
CHECK_EQ(stack(z, 0).shape(), std::vector<int>{2, 3});
|
||||||
|
CHECK_EQ(stack(z, 1).shape(), std::vector<int>{3, 2});
|
||||||
|
CHECK_EQ(stack(z, -1).shape(), std::vector<int>{3, 2});
|
||||||
|
CHECK_EQ(stack(z, -2).shape(), std::vector<int>{2, 3});
|
||||||
|
|
||||||
|
CHECK_THROWS_MESSAGE(stack({}, 0), "No arrays provided for stacking");
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3}, float16);
|
||||||
|
y = array({4, 5, 6}, {3}, int32);
|
||||||
|
CHECK_EQ(stack({x, y}, 0).dtype(), float16);
|
||||||
|
|
||||||
|
x = array({1, 2, 3}, {3}, int32);
|
||||||
|
y = array({4, 5, 6, 7}, {4}, int32);
|
||||||
|
CHECK_THROWS_MESSAGE(
|
||||||
|
stack({x, y}, 0), "All arrays must have the same shape and dtype");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test eye") {
|
TEST_CASE("test eye") {
|
||||||
auto eye_3 = eye(3);
|
auto eye_3 = eye(3);
|
||||||
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
|
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
|
||||||
|
@ -25,3 +25,38 @@ TEST_CASE("test type promotion") {
|
|||||||
CHECK_EQ(result_type(arrs), float32);
|
CHECK_EQ(result_type(arrs), float32);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test normalize axis") {
|
||||||
|
struct TestCase {
|
||||||
|
int axis;
|
||||||
|
int ndim;
|
||||||
|
int expected;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<TestCase> testCases = {
|
||||||
|
{0, 3, 0}, {1, 3, 1}, {2, 3, 2}, {-1, 3, 2}, {-2, 3, 1}, {-3, 3, 0}};
|
||||||
|
|
||||||
|
for (const auto& tc : testCases) {
|
||||||
|
CHECK_EQ(normalize_axis(tc.axis, tc.ndim), tc.expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_THROWS(normalize_axis(3, 3));
|
||||||
|
CHECK_THROWS(normalize_axis(-4, 3));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test is same size and shape") {
|
||||||
|
struct TestCase {
|
||||||
|
std::vector<array> a;
|
||||||
|
bool expected;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<TestCase> testCases = {
|
||||||
|
{{array({}), array({})}, true},
|
||||||
|
{{array({1}), array({1})}, true},
|
||||||
|
{{array({1, 2, 3}), array({1, 2, 4})}, true},
|
||||||
|
{{array({1, 2, 3}), array({1, 2})}, false}};
|
||||||
|
|
||||||
|
for (const auto& tc : testCases) {
|
||||||
|
CHECK_EQ(is_same_shape(tc.a), tc.expected);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user