mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
This commit is contained in:
31
mlx/ops.cpp
31
mlx/ops.cpp
@@ -574,11 +574,11 @@ array concatenate(
|
||||
shape[ax] += a.shape(ax);
|
||||
}
|
||||
|
||||
// Promote all the arrays to the same type
|
||||
auto dtype = result_type(arrays);
|
||||
|
||||
return array(
|
||||
shape,
|
||||
arrays[0].dtype(),
|
||||
std::make_unique<Concatenate>(to_stream(s), ax),
|
||||
arrays);
|
||||
shape, dtype, std::make_unique<Concatenate>(to_stream(s), ax), arrays);
|
||||
}
|
||||
|
||||
array concatenate(
|
||||
@@ -591,6 +591,29 @@ array concatenate(
|
||||
return concatenate(flat_inputs, 0, s);
|
||||
}
|
||||
|
||||
/** Stack arrays along a new axis */
|
||||
array stack(
|
||||
const std::vector<array>& arrays,
|
||||
int axis,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (arrays.empty()) {
|
||||
throw std::invalid_argument("No arrays provided for stacking");
|
||||
}
|
||||
if (!is_same_shape(arrays)) {
|
||||
throw std::invalid_argument("All arrays must have the same shape");
|
||||
}
|
||||
int normalized_axis = normalize_axis(axis, arrays[0].ndim() + 1);
|
||||
std::vector<array> new_arrays;
|
||||
new_arrays.reserve(arrays.size());
|
||||
for (auto& a : arrays) {
|
||||
new_arrays.emplace_back(expand_dims(a, normalized_axis, s));
|
||||
}
|
||||
return concatenate(new_arrays, axis, s);
|
||||
}
|
||||
/** Stack arrays along a new leading axis (axis 0). */
array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
  constexpr int leading_axis = 0;
  return stack(arrays, leading_axis, s);
}
|
||||
|
||||
/** Pad an array with a constant value */
|
||||
array pad(
|
||||
const array& a,
|
||||
|
@@ -174,6 +174,10 @@ array concatenate(
|
||||
StreamOrDevice s = {});
|
||||
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||
|
||||
/** Stack arrays along a new axis. */
|
||||
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
|
||||
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
||||
|
||||
/** Permutes the dimensions according to the given axes. */
|
||||
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
||||
inline array transpose(
|
||||
|
@@ -49,6 +49,31 @@ std::vector<int> broadcast_shapes(
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
bool is_same_shape(const std::vector<array>& arrays) {
|
||||
if (arrays.empty())
|
||||
return true;
|
||||
|
||||
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
|
||||
return (a.shape() == arrays[0].shape());
|
||||
});
|
||||
}
|
||||
|
||||
/** Map `axis` (possibly negative) into the range [0, ndim).
 *
 * Mirrors numpy's normalize_axis_index. See
 * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
 *
 * @param axis Axis index, valid in [-ndim, ndim).
 * @param ndim Number of dimensions; must be positive.
 * @throws std::invalid_argument if `ndim <= 0` or `axis` is out of range.
 */
int normalize_axis(int axis, int ndim) {
  if (ndim <= 0) {
    throw std::invalid_argument("Number of dimensions must be positive.");
  }
  const bool in_range = (axis >= -ndim) && (axis < ndim);
  if (!in_range) {
    std::ostringstream msg;
    msg << "Axis " << axis << " is out of bounds for array with " << ndim
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  // Negative axes count back from the end.
  return axis < 0 ? axis + ndim : axis;
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||
os << "Device(";
|
||||
switch (d.type) {
|
||||
|
@@ -16,6 +16,15 @@ std::vector<int> broadcast_shapes(
|
||||
const std::vector<int>& s1,
|
||||
const std::vector<int>& s2);
|
||||
|
||||
bool is_same_shape(const std::vector<array>& arrays);
|
||||
|
||||
/**
|
||||
* Returns the axis normalized to be in the range [0, ndim).
|
||||
* Based on numpy's normalize_axis_index. See
|
||||
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
||||
*/
|
||||
int normalize_axis(int axis, int ndim);
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Device& d);
|
||||
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
||||
|
Reference in New Issue
Block a user