feat: add cross_product (#1252)

* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-09-17 21:12:43 +01:00
committed by GitHub
parent 4f46e9c997
commit 6af5ca35b2
7 changed files with 204 additions and 1 deletions

View File

@@ -382,4 +382,76 @@ array cholesky_inv(
}
}
array cross(
const array& a,
const array& b,
int axis /* = -1 */,
StreamOrDevice s /* = {} */) {
auto check_ax = [axis](const array& arr) {
if (axis >= static_cast<int>(arr.ndim()) || axis + arr.ndim() < 0) {
std::ostringstream msg;
msg << "[linalg::cross] axis " << axis << " invalid for array with "
<< arr.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (arr.shape(axis) < 2 || arr.shape(axis) > 3) {
throw std::invalid_argument(
"[linalg::cross] The specified axis must have size 2 or 3.");
}
};
check_ax(a);
check_ax(b);
bool a_2d = a.shape(axis) == 2;
bool b_2d = b.shape(axis) == 2;
auto out_type = promote_types(a.dtype(), b.dtype());
auto ashape = a.shape();
auto bshape = b.shape();
ashape[axis < 0 ? axis + a.ndim() : axis] = 3;
bshape[axis < 0 ? axis + b.ndim() : axis] = 3;
auto out_shape = broadcast_shapes(ashape, bshape);
if (axis < 0) {
axis += out_shape.size();
}
out_shape[axis] = a_2d ? 2 : 3;
auto a_ = broadcast_to(astype(a, out_type, s), out_shape, s);
out_shape[axis] = b_2d ? 2 : 3;
auto b_ = broadcast_to(astype(b, out_type, s), out_shape, s);
auto a_splits = split(a_, a_2d ? 2 : 3, axis);
auto b_splits = split(b_, b_2d ? 2 : 3, axis);
std::vector<array> outputs;
if (a_2d && b_2d) {
auto z = zeros_like(a_splits[0], s);
outputs.push_back(z);
outputs.push_back(z);
} else if (b_2d) {
outputs.push_back(negative(multiply(a_splits[2], b_splits[1], s), s));
outputs.push_back(multiply(a_splits[2], b_splits[0], s));
} else if (a_2d) {
outputs.push_back(multiply(a_splits[1], b_splits[2], s));
outputs.push_back(negative(multiply(a_splits[0], b_splits[2], s), s));
} else {
outputs.push_back(subtract(
multiply(a_splits[1], b_splits[2], s),
multiply(a_splits[2], b_splits[1], s),
s));
outputs.push_back(subtract(
multiply(a_splits[2], b_splits[0], s),
multiply(a_splits[0], b_splits[2], s),
s));
}
outputs.push_back(subtract(
multiply(a_splits[0], b_splits[1], s),
multiply(a_splits[1], b_splits[0], s),
s));
return concatenate(outputs, axis, s);
}
} // namespace mlx::core::linalg

View File

@@ -74,4 +74,13 @@ array pinv(const array& a, StreamOrDevice s = {});
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
/**
* Compute the cross product of two arrays along the given axis.
*/
array cross(
const array& a,
const array& b,
int axis = -1,
StreamOrDevice s = {});
} // namespace mlx::core::linalg