feat: add cross_product (#1252)

* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan 2024-09-17 21:12:43 +01:00 committed by GitHub
parent 4f46e9c997
commit 6af5ca35b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 204 additions and 1 deletions

View File

@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.

View File

@ -13,5 +13,6 @@ Linear Algebra
norm
cholesky
cholesky_inv
cross
qr
svd

View File

@ -382,4 +382,76 @@ array cholesky_inv(
}
}
/**
 * Compute the cross product of `a` and `b` along `axis`.
 *
 * Both inputs must have size 2 or 3 along `axis`; a size-2 input is treated
 * as if its third component were zero. The inputs are promoted to a common
 * type and broadcast against each other. The result always has size 3 along
 * `axis`.
 *
 * Throws std::invalid_argument if `axis` is out of range for either input or
 * if either input's size along `axis` is not 2 or 3.
 */
array cross(
    const array& a,
    const array& b,
    int axis /* = -1 */,
    StreamOrDevice s /* = {} */) {
  auto check_ax = [axis](const array& arr) {
    // Do the range check in signed arithmetic. ndim() is unsigned (hence the
    // cast), so `axis + arr.ndim() < 0` would be evaluated as unsigned and
    // could never be true, letting an out-of-range negative axis through.
    const int ndim = static_cast<int>(arr.ndim());
    if (axis >= ndim || axis + ndim < 0) {
      std::ostringstream msg;
      msg << "[linalg::cross] axis " << axis << " invalid for array with "
          << ndim << " dimensions.";
      throw std::invalid_argument(msg.str());
    }
    if (arr.shape(axis) < 2 || arr.shape(axis) > 3) {
      throw std::invalid_argument(
          "[linalg::cross] The specified axis must have size 2 or 3.");
    }
  };
  check_ax(a);
  check_ax(b);

  bool a_2d = a.shape(axis) == 2;
  bool b_2d = b.shape(axis) == 2;

  auto out_type = promote_types(a.dtype(), b.dtype());

  // Compute the broadcast shape as if both inputs had size 3 along the axis,
  // so that a size-2 input can be combined with a size-3 input.
  auto ashape = a.shape();
  auto bshape = b.shape();
  ashape[axis < 0 ? axis + static_cast<int>(a.ndim()) : axis] = 3;
  bshape[axis < 0 ? axis + static_cast<int>(b.ndim()) : axis] = 3;
  auto out_shape = broadcast_shapes(ashape, bshape);

  // From here on `axis` is a non-negative index into the broadcast shape.
  if (axis < 0) {
    axis += static_cast<int>(out_shape.size());
  }

  // Broadcast each input to the common shape while keeping its own size
  // (2 or 3) along the cross-product axis.
  out_shape[axis] = a_2d ? 2 : 3;
  auto a_ = broadcast_to(astype(a, out_type, s), out_shape, s);

  out_shape[axis] = b_2d ? 2 : 3;
  auto b_ = broadcast_to(astype(b, out_type, s), out_shape, s);

  // Split into per-component slices; run the splits on the requested stream
  // like every other op in this function.
  auto a_splits = split(a_, a_2d ? 2 : 3, axis, s);
  auto b_splits = split(b_, b_2d ? 2 : 3, axis, s);

  // First two output components. Terms that involve a missing third input
  // component are dropped because that component is taken to be zero.
  std::vector<array> outputs;
  if (a_2d && b_2d) {
    auto z = zeros_like(a_splits[0], s);
    outputs.push_back(z);
    outputs.push_back(z);
  } else if (b_2d) {
    outputs.push_back(negative(multiply(a_splits[2], b_splits[1], s), s));
    outputs.push_back(multiply(a_splits[2], b_splits[0], s));
  } else if (a_2d) {
    outputs.push_back(multiply(a_splits[1], b_splits[2], s));
    outputs.push_back(negative(multiply(a_splits[0], b_splits[2], s), s));
  } else {
    outputs.push_back(subtract(
        multiply(a_splits[1], b_splits[2], s),
        multiply(a_splits[2], b_splits[1], s),
        s));
    outputs.push_back(subtract(
        multiply(a_splits[2], b_splits[0], s),
        multiply(a_splits[0], b_splits[2], s),
        s));
  }

  // The third component only uses the first two input components, so it is
  // the same expression in every case.
  outputs.push_back(subtract(
      multiply(a_splits[0], b_splits[1], s),
      multiply(a_splits[1], b_splits[0], s),
      s));
  return concatenate(outputs, axis, s);
}
} // namespace mlx::core::linalg

View File

@ -74,4 +74,13 @@ array pinv(const array& a, StreamOrDevice s = {});
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
/**
 * Compute the cross product of two arrays along the given axis.
 *
 * The size of `a` and `b` along `axis` must be 2 or 3. An input of size 2 is
 * treated as if its third component were zero. The inputs are promoted to a
 * common type and broadcast against each other; the result has size 3 along
 * `axis`.
 *
 * Throws std::invalid_argument if the size of either input along `axis` is
 * not 2 or 3.
 */
array cross(
const array& a,
const array& b,
int axis = -1,
StreamOrDevice s = {});
} // namespace mlx::core::linalg

View File

@ -377,4 +377,32 @@ void init_linalg(nb::module_& parent_module) {
Returns:
array: ``aplus`` such that ``a @ aplus @ a = a``
)pbdoc");
// Register `cross` on the module. `a`, `b` and `axis` (default -1) may be
// passed positionally; `stream` is keyword-only and defaults to None.
m.def(
"cross",
&cross,
"a"_a,
"b"_a,
"axis"_a = -1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def cross(a: array, b: array, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the cross product of two arrays along a specified axis.
The cross product is defined for arrays with size 2 or 3 in the
specified axis. If the size is 2 then the third value is assumed
to be zero.
Args:
a (array): Input array.
b (array): Input array.
axis (int, optional): Axis along which to compute the cross
product. Default: ``-1``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The cross product of ``a`` and ``b`` along the specified axis.
)pbdoc");
}

View File

@ -220,6 +220,54 @@ class TestLinalg(mlx_tests.MLXTestCase):
for M, M_inv in zip(AB, AB_inv):
self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4))
def test_cross_product(self):
    """Compare ``mx.linalg.cross`` with ``np.cross`` on several inputs and
    check that vectors of an unsupported size are rejected."""

    def check_against_numpy(lhs, rhs, **kwargs):
        # Both implementations get the same operands and keyword arguments.
        actual = mx.linalg.cross(lhs, rhs, **kwargs)
        reference = np.cross(lhs, rhs, **kwargs)
        self.assertTrue(np.allclose(actual, reference))

    # Plain 3-vectors
    check_against_numpy(
        mx.array([1.0, 2.0, 3.0]),
        mx.array([4.0, 5.0, 6.0]),
    )

    # Negative components
    check_against_numpy(
        mx.array([-1.0, -2.0, -3.0]),
        mx.array([4.0, -5.0, 6.0]),
    )

    # Integer inputs
    check_against_numpy(
        mx.array([1, 2, 3]),
        mx.array([4, 5, 6]),
    )

    # Matrices with an explicit axis
    check_against_numpy(
        mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        mx.array([[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]),
        axis=1,
    )

    # Operands that need broadcasting
    check_against_numpy(
        mx.random.uniform(shape=(2, 1, 3)),
        mx.random.uniform(shape=(1, 2, 3)),
    )

    # Mixed float / integer operands (type promotion)
    check_against_numpy(
        mx.array([1.0, 2.0, 3.0]),
        mx.array([4, 5, 6]),
    )

    # Length-1 vectors are not a valid size and must raise
    too_short_a = mx.array([1.0])
    too_short_b = mx.array([4.0])
    with self.assertRaises(ValueError):
        mx.linalg.cross(too_short_a, too_short_b)
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
unittest.main()

View File

@ -390,3 +390,48 @@ TEST_CASE("test matrix pseudo-inverse") {
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
}
}
// Checks linalg::cross against hand-computed results. For inputs a and b the
// expected components are written out as
// (a1*b2 - a2*b1, a2*b0 - a0*b2, a0*b1 - a1*b0).
TEST_CASE("test cross product") {
using namespace mlx::core::linalg;
// Test for vectors of length 3
array a = array({1.0, 2.0, 3.0});
array b = array({4.0, 5.0, 6.0});
array expected = array(
{2.0 * 6.0 - 3.0 * 5.0, 3.0 * 4.0 - 1.0 * 6.0, 1.0 * 5.0 - 2.0 * 4.0});
array result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
// Test for vectors of length 3 with negative values
a = array({-1.0, -2.0, -3.0});
b = array({4.0, -5.0, 6.0});
expected = array(
{-2.0 * 6.0 - (-3.0 * -5.0),
-3.0 * 4.0 - (-1.0 * 6.0),
-1.0 * -5.0 - (-2.0 * 4.0)});
result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
// Test a length-3 vector crossed with a length-2 vector. This is a valid
// input (it does not throw): the expected values below use 0.0 for the
// missing third component of b.
b = array({1.0, 2.0});
expected = array(
{-2.0 * 0.0 - (-3.0 * 2.0),
-3.0 * 1.0 - (-1.0 * 0.0),
-1.0 * 2.0 - (-2.0 * 1.0)});
result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
// Test for vectors of length 3 with integer values
a = array({1, 2, 3});
b = array({4, 5, 6});
expected = array({2 * 6 - 3 * 5, 3 * 4 - 1 * 6, 1 * 5 - 2 * 4});
result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
}