feat: add cross_product (#1252)

* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-09-17 21:12:43 +01:00
committed by GitHub
parent 4f46e9c997
commit 6af5ca35b2
7 changed files with 204 additions and 1 deletions

View File

@@ -377,4 +377,32 @@ void init_linalg(nb::module_& parent_module) {
Returns:
array: ``aplus`` such that ``a @ aplus @ a = a``
)pbdoc");
m.def(
"cross",
&cross,
"a"_a,
"b"_a,
"axis"_a = -1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def cross(a: array, b: array, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the cross product of two arrays along a specified axis.
The cross product is defined for arrays with size 2 or 3 in the
specified axis. If the size is 2 then the third value is assumed
to be zero.
Args:
a (array): Input array.
b (array): Input array.
axis (int, optional): Axis along which to compute the cross
product. Default: ``-1``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The cross product of ``a`` and ``b`` along the specified axis.
)pbdoc");
}