Added mx.stack c++ frontend impl (#123)

* stack C++ operation + python bindings
This commit is contained in:
Jason
2023-12-14 16:21:19 -05:00
committed by GitHub
parent e5851e52b1
commit e28b57e371
9 changed files with 191 additions and 4 deletions

View File

@@ -574,11 +574,11 @@ array concatenate(
shape[ax] += a.shape(ax);
}
// Promote all the arrays to the same type
auto dtype = result_type(arrays);
return array(
shape,
arrays[0].dtype(),
std::make_unique<Concatenate>(to_stream(s), ax),
arrays);
shape, dtype, std::make_unique<Concatenate>(to_stream(s), ax), arrays);
}
array concatenate(
@@ -591,6 +591,29 @@ array concatenate(
return concatenate(flat_inputs, 0, s);
}
/** Stack arrays along a new axis */
array stack(
const std::vector<array>& arrays,
int axis,
StreamOrDevice s /* = {} */) {
if (arrays.empty()) {
throw std::invalid_argument("No arrays provided for stacking");
}
if (!is_same_shape(arrays)) {
throw std::invalid_argument("All arrays must have the same shape");
}
int normalized_axis = normalize_axis(axis, arrays[0].ndim() + 1);
std::vector<array> new_arrays;
new_arrays.reserve(arrays.size());
for (auto& a : arrays) {
new_arrays.emplace_back(expand_dims(a, normalized_axis, s));
}
return concatenate(new_arrays, axis, s);
}
array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
return stack(arrays, 0, s);
}
/** Pad an array with a constant value */
array pad(
const array& a,

View File

@@ -174,6 +174,10 @@ array concatenate(
StreamOrDevice s = {});
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
/** Stack arrays along a new axis. */
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
/** Permutes the dimensions according to the given axes. */
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
inline array transpose(

View File

@@ -49,6 +49,31 @@ std::vector<int> broadcast_shapes(
return out_shape;
}
bool is_same_shape(const std::vector<array>& arrays) {
if (arrays.empty())
return true;
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
return (a.shape() == arrays[0].shape());
});
}
int normalize_axis(int axis, int ndim) {
if (ndim <= 0) {
throw std::invalid_argument("Number of dimensions must be positive.");
}
if (axis < -ndim || axis >= ndim) {
std::ostringstream msg;
msg << "Axis " << axis << " is out of bounds for array with " << ndim
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (axis < 0) {
axis += ndim;
}
return axis;
}
std::ostream& operator<<(std::ostream& os, const Device& d) {
os << "Device(";
switch (d.type) {

View File

@@ -16,6 +16,15 @@ std::vector<int> broadcast_shapes(
const std::vector<int>& s1,
const std::vector<int>& s2);
bool is_same_shape(const std::vector<array>& arrays);
/**
* Returns the axis normalized to be in the range [0, ndim).
* Based on numpy's normalize_axis_index. See
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
*/
int normalize_axis(int axis, int ndim);
std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);