mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 05:01:19 +08:00
use axes in tensordot (#525)
This commit is contained in:
parent
f326dd8334
commit
98c37d3a22
@ -3400,20 +3400,20 @@ void init_ops(py::module_& m) {
|
|||||||
"a"_a,
|
"a"_a,
|
||||||
"b"_a,
|
"b"_a,
|
||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
"dims"_a = 2,
|
"axes"_a = 2,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
|
tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Compute the tensor dot product along the specified axes.
|
Compute the tensor dot product along the specified axes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array
|
a (array): Input array
|
||||||
b (array): Input array
|
b (array): Input array
|
||||||
dims (int or list(list(int)), optional): The number of dimensions to
|
axes (int or list(list(int)), optional): The number of dimensions to
|
||||||
sum over. If an integer is provided, then sum over the last
|
sum over. If an integer is provided, then sum over the last
|
||||||
``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
|
``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
|
||||||
``b``. If a list of lists is provided, then sum over the
|
``b``. If a list of lists is provided, then sum over the
|
||||||
corresponding dimensions of ``a`` and ``b``. (default: 2)
|
corresponding dimensions of ``a`` and ``b``. (default: 2)
|
||||||
|
|
||||||
|
@ -1644,23 +1644,23 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertCmpNumpy(
|
self.assertCmpNumpy(
|
||||||
[(3, 4, 5), (4, 3, 2)],
|
[(3, 4, 5), (4, 3, 2)],
|
||||||
mx.tensordot,
|
mx.tensordot,
|
||||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
np.tensordot,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
dims=([1, 0], [0, 1]),
|
axes=([1, 0], [0, 1]),
|
||||||
)
|
)
|
||||||
self.assertCmpNumpy(
|
self.assertCmpNumpy(
|
||||||
[(3, 4, 5), (4, 5, 6)],
|
[(3, 4, 5), (4, 5, 6)],
|
||||||
mx.tensordot,
|
mx.tensordot,
|
||||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
np.tensordot,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
dims=2,
|
axes=2,
|
||||||
)
|
)
|
||||||
self.assertCmpNumpy(
|
self.assertCmpNumpy(
|
||||||
[(3, 5, 4, 6), (6, 4, 5, 3)],
|
[(3, 5, 4, 6), (6, 4, 5, 3)],
|
||||||
mx.tensordot,
|
mx.tensordot,
|
||||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
np.tensordot,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
dims=([2, 1, 3], [1, 2, 0]),
|
axes=([2, 1, 3], [1, 2, 0]),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_inner(self):
|
def test_inner(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user