mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 23:38:09 +08:00
Add Tensordot op (#344)
This commit is contained in:
@@ -3194,4 +3194,44 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
result (array): The dequantized version of ``w``
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tensordot",
|
||||
[](const array& a,
|
||||
const array& b,
|
||||
const std::variant<int, std::vector<std::vector<int>>>& dims,
|
||||
StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&dims); pv) {
|
||||
return tensordot(a, b, *pv, s);
|
||||
} else {
|
||||
auto x = std::get<std::vector<std::vector<int>>>(dims);
|
||||
if (x.size() != 2) {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] dims must be a list of two lists.");
|
||||
}
|
||||
return tensordot(a, b, {x[0], x[1]}, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
"dims"_a = 2,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Compute the tensor dot product along the specified axes.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
b (array): Input array
|
||||
dims (int or list(list(int)), optional): The number of dimensions to
|
||||
sum over. If an integer is provided, then sum over the last
|
||||
``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
|
||||
``b``. If a list of lists is provided, then sum over the
|
||||
corresponding dimensions of ``a`` and ``b``. (default: 2)
|
||||
|
||||
Returns:
|
||||
result (array): The tensor dot product.
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
@@ -1547,6 +1547,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected_3 = np.repeat(data_3, 2, axis=0)
|
||||
self.assertEqualArray(repeat_3, mx.array(expected_3))
|
||||
|
||||
def test_tensordot(self):
|
||||
x = mx.arange(60.0).reshape(3, 4, 5)
|
||||
y = mx.arange(24.0).reshape(4, 3, 2)
|
||||
z = mx.tensordot(x, y, dims=([1, 0], [0, 1]))
|
||||
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1]))))
|
||||
x = mx.random.normal((3, 4, 5))
|
||||
y = mx.random.normal((4, 5, 6))
|
||||
z = mx.tensordot(x, y, dims=2)
|
||||
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2)))
|
||||
x = mx.random.normal((3, 5, 4, 6))
|
||||
y = mx.random.normal((6, 4, 5, 3))
|
||||
z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0]))
|
||||
self.assertEqualArray(
|
||||
z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0])))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user