mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
use axes in tensordot (#525)
This commit is contained in:
@@ -3400,20 +3400,20 @@ void init_ops(py::module_& m) {
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
"dims"_a = 2,
|
||||
"axes"_a = 2,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Compute the tensor dot product along the specified axes.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
b (array): Input array
|
||||
dims (int or list(list(int)), optional): The number of dimensions to
|
||||
axes (int or list(list(int)), optional): The number of dimensions to
|
||||
sum over. If an integer is provided, then sum over the last
|
||||
``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
|
||||
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
|
||||
``b``. If a list of lists is provided, then sum over the
|
||||
corresponding dimensions of ``a`` and ``b``. (default: 2)
|
||||
|
||||
|
Reference in New Issue
Block a user