Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
feat: add cross_product (#1252)
* feat: add cross_product
* lint
* python binding
* refactor: Improve error message for cross_product function
* refactor: more close to numpy cross product
* refactor: improve error message for cross_product function
* finish
* fix acks
* allow old numpy
* doc

Co-authored-by: Awni Hannun <awni@apple.com>
Parent: 4f46e9c997
Commit: 6af5ca35b2
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
--- a/docs/src/python/linalg.rst
+++ b/docs/src/python/linalg.rst
@@ -13,5 +13,6 @@ Linear Algebra
    norm
    cholesky
    cholesky_inv
+   cross
    qr
    svd
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -382,4 +382,76 @@ array cholesky_inv(
   }
 }
 
+array cross(
+    const array& a,
+    const array& b,
+    int axis /* = -1 */,
+    StreamOrDevice s /* = {} */) {
+  auto check_ax = [axis](const array& arr) {
+    if (axis >= static_cast<int>(arr.ndim()) ||
+        axis + static_cast<int>(arr.ndim()) < 0) {
+      std::ostringstream msg;
+      msg << "[linalg::cross] axis " << axis << " invalid for array with "
+          << arr.ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if (arr.shape(axis) < 2 || arr.shape(axis) > 3) {
+      throw std::invalid_argument(
+          "[linalg::cross] The specified axis must have size 2 or 3.");
+    }
+  };
+  check_ax(a);
+  check_ax(b);
+
+  bool a_2d = a.shape(axis) == 2;
+  bool b_2d = b.shape(axis) == 2;
+
+  auto out_type = promote_types(a.dtype(), b.dtype());
+  auto ashape = a.shape();
+  auto bshape = b.shape();
+
+  // Treat a size-2 axis as size 3 so the two shapes broadcast together.
+  ashape[axis < 0 ? axis + a.ndim() : axis] = 3;
+  bshape[axis < 0 ? axis + b.ndim() : axis] = 3;
+  auto out_shape = broadcast_shapes(ashape, bshape);
+
+  if (axis < 0) {
+    axis += out_shape.size();
+  }
+
+  out_shape[axis] = a_2d ? 2 : 3;
+  auto a_ = broadcast_to(astype(a, out_type, s), out_shape, s);
+
+  out_shape[axis] = b_2d ? 2 : 3;
+  auto b_ = broadcast_to(astype(b, out_type, s), out_shape, s);
+
+  // Split each input into its scalar components along the given axis.
+  auto a_splits = split(a_, a_2d ? 2 : 3, axis);
+  auto b_splits = split(b_, b_2d ? 2 : 3, axis);
+
+  std::vector<array> outputs;
+  if (a_2d && b_2d) {
+    // Both third components are zero, so x and y of the result are zero.
+    auto z = zeros_like(a_splits[0], s);
+    outputs.push_back(z);
+    outputs.push_back(z);
+  } else if (b_2d) {
+    // b's third component is zero: x = -a2 * b1, y = a2 * b0.
+    outputs.push_back(negative(multiply(a_splits[2], b_splits[1], s), s));
+    outputs.push_back(multiply(a_splits[2], b_splits[0], s));
+  } else if (a_2d) {
+    // a's third component is zero: x = a1 * b2, y = -a0 * b2.
+    outputs.push_back(multiply(a_splits[1], b_splits[2], s));
+    outputs.push_back(negative(multiply(a_splits[0], b_splits[2], s), s));
+  } else {
+    outputs.push_back(subtract(
+        multiply(a_splits[1], b_splits[2], s),
+        multiply(a_splits[2], b_splits[1], s),
+        s));
+    outputs.push_back(subtract(
+        multiply(a_splits[2], b_splits[0], s),
+        multiply(a_splits[0], b_splits[2], s),
+        s));
+  }
+  // The z component, a0 * b1 - a1 * b0, is common to every case.
+  outputs.push_back(subtract(
+      multiply(a_splits[0], b_splits[1], s),
+      multiply(a_splits[1], b_splits[0], s),
+      s));
+  return concatenate(outputs, axis, s);
+}
+
 } // namespace mlx::core::linalg
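The arithmetic above is easier to scan in array notation. Below is a minimal NumPy sketch of the same idea, written for this note (the helper `cross_ref` is hypothetical, not part of MLX or this commit): a size-2 axis is padded with an implicit zero third component, then the three component formulas are applied. Like the C++ code above, and unlike `np.cross`, it still returns three components when both inputs have size 2.

```python
import numpy as np

def cross_ref(a: np.ndarray, b: np.ndarray, axis: int = -1) -> np.ndarray:
    # Pad a size-2 cross-product axis with a zero third component.
    def pad3(x: np.ndarray) -> np.ndarray:
        if x.shape[axis] == 2:
            zero = np.zeros_like(np.take(x, [0], axis=axis))
            return np.concatenate([x, zero], axis=axis)
        return x

    # Move the cross-product axis to the front and unpack the components.
    a0, a1, a2 = np.moveaxis(pad3(a), axis, 0)
    b0, b1, b2 = np.moveaxis(pad3(b), axis, 0)
    out = np.stack([
        a1 * b2 - a2 * b1,  # x
        a2 * b0 - a0 * b2,  # y
        a0 * b1 - a1 * b0,  # z
    ])
    return np.moveaxis(out, 0, axis)

print(cross_ref(np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])))
# [-3.  6. -3.]
```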
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -74,4 +74,13 @@ array pinv(const array& a, StreamOrDevice s = {});
 
 array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
 
+/**
+ * Compute the cross product of two arrays along the given axis.
+ */
+array cross(
+    const array& a,
+    const array& b,
+    int axis = -1,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -377,4 +377,32 @@ void init_linalg(nb::module_& parent_module) {
       Returns:
         array: ``aplus`` such that ``a @ aplus @ a = a``
       )pbdoc");
+  m.def(
+      "cross",
+      &cross,
+      "a"_a,
+      "b"_a,
+      "axis"_a = -1,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def cross(a: array, b: array, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Compute the cross product of two arrays along a specified axis.
+
+        The cross product is defined for arrays with size 2 or 3 in the
+        specified axis. If the size is 2, the third value is assumed to
+        be zero.
+
+        Args:
+          a (array): Input array.
+          b (array): Input array.
+          axis (int, optional): Axis along which to compute the cross
+            product. Default: ``-1``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``,
+            in which case the default stream of the default device is used.
+
+        Returns:
+          array: The cross product of ``a`` and ``b`` along the specified axis.
+      )pbdoc");
 }
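A quick usage sketch of the new binding (the printed forms are approximate and may vary by MLX version; the values follow from the component formulas above):

```python
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
print(mx.linalg.cross(a, b))  # array([-3, 6, -3], dtype=float32)

# A size-2 input is treated as having a zero third component,
# so x-hat cross y-hat gives z-hat:
print(mx.linalg.cross(mx.array([1.0, 0.0]), mx.array([0.0, 1.0])))
# array([0, 0, 1], dtype=float32)
```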
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -220,6 +220,54 @@ class TestLinalg(mlx_tests.MLXTestCase):
         for M, M_inv in zip(AB, AB_inv):
             self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4))
 
+    def test_cross_product(self):
+        a = mx.array([1.0, 2.0, 3.0])
+        b = mx.array([4.0, 5.0, 6.0])
+        result = mx.linalg.cross(a, b)
+        expected = np.cross(a, b)
+        self.assertTrue(np.allclose(result, expected))
+
+        # Test with negative values
+        a = mx.array([-1.0, -2.0, -3.0])
+        b = mx.array([4.0, -5.0, 6.0])
+        result = mx.linalg.cross(a, b)
+        expected = np.cross(a, b)
+        self.assertTrue(np.allclose(result, expected))
+
+        # Test with integer values
+        a = mx.array([1, 2, 3])
+        b = mx.array([4, 5, 6])
+        result = mx.linalg.cross(a, b)
+        expected = np.cross(a, b)
+        self.assertTrue(np.allclose(result, expected))
+
+        # Test with 2D arrays and axis parameter
+        a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+        b = mx.array([[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]])
+        result = mx.linalg.cross(a, b, axis=1)
+        expected = np.cross(a, b, axis=1)
+        self.assertTrue(np.allclose(result, expected))
+
+        # Test with broadcasting
+        a = mx.random.uniform(shape=(2, 1, 3))
+        b = mx.random.uniform(shape=(1, 2, 3))
+        result = mx.linalg.cross(a, b)
+        expected = np.cross(a, b)
+        self.assertTrue(np.allclose(result, expected))
+
+        # Type promotion
+        a = mx.array([1.0, 2.0, 3.0])
+        b = mx.array([4, 5, 6])
+        result = mx.linalg.cross(a, b)
+        expected = np.cross(a, b)
+        self.assertTrue(np.allclose(result, expected))
+
+        # Test with incorrect vector size (should raise an exception)
+        a = mx.array([1.0])
+        b = mx.array([4.0])
+        with self.assertRaises(ValueError):
+            mx.linalg.cross(a, b)
+
 
 if __name__ == "__main__":
     unittest.main()
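For the broadcast case in the test above, the non-cross axes broadcast pairwise while the cross-product axis keeps size 3; a small NumPy check of the resulting shape (illustrative only):

```python
import numpy as np

a = np.random.uniform(size=(2, 1, 3))
b = np.random.uniform(size=(1, 2, 3))
# (2, 1) and (1, 2) broadcast to (2, 2); the last axis stays 3.
print(np.cross(a, b).shape)  # (2, 2, 3)
```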
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -390,3 +390,48 @@ TEST_CASE("test matrix pseudo-inverse") {
     CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
   }
 }
+
+TEST_CASE("test cross product") {
+  using namespace mlx::core::linalg;
+
+  // Test for vectors of length 3
+  array a = array({1.0, 2.0, 3.0});
+  array b = array({4.0, 5.0, 6.0});
+
+  array expected = array(
+      {2.0 * 6.0 - 3.0 * 5.0, 3.0 * 4.0 - 1.0 * 6.0, 1.0 * 5.0 - 2.0 * 4.0});
+
+  array result = cross(a, b);
+  CHECK(allclose(result, expected).item<bool>());
+
+  // Test for vectors of length 3 with negative values
+  a = array({-1.0, -2.0, -3.0});
+  b = array({4.0, -5.0, 6.0});
+
+  expected = array(
+      {-2.0 * 6.0 - (-3.0 * -5.0),
+       -3.0 * 4.0 - (-1.0 * 6.0),
+       -1.0 * -5.0 - (-2.0 * 4.0)});
+
+  result = cross(a, b);
+  CHECK(allclose(result, expected).item<bool>());
+
+  // Test with a size-2 vector (its third component is assumed to be zero)
+  b = array({1.0, 2.0});
+  expected = array(
+      {-2.0 * 0.0 - (-3.0 * 2.0),
+       -3.0 * 1.0 - (-1.0 * 0.0),
+       -1.0 * 2.0 - (-2.0 * 1.0)});
+
+  result = cross(a, b);
+  CHECK(allclose(result, expected).item<bool>());
+
+  // Test for vectors of length 3 with integer values
+  a = array({1, 2, 3});
+  b = array({4, 5, 6});
+
+  expected = array({2 * 6 - 3 * 5, 3 * 4 - 1 * 6, 1 * 5 - 2 * 4});
+
+  result = cross(a, b);
+  CHECK(allclose(result, expected).item<bool>());
+}