Add Tensordot op (#344)

This commit is contained in:
Diogo
2024-01-02 20:15:00 -05:00
committed by GitHub
parent af66a09bde
commit 0782a4573a
7 changed files with 198 additions and 1 deletions

View File

@@ -3194,4 +3194,44 @@ void init_ops(py::module_& m) {
Returns:
result (array): The dequantized version of ``w``
)pbdoc");
m.def(
"tensordot",
[](const array& a,
const array& b,
const std::variant<int, std::vector<std::vector<int>>>& dims,
StreamOrDevice s) {
if (auto pv = std::get_if<int>(&dims); pv) {
return tensordot(a, b, *pv, s);
} else {
auto x = std::get<std::vector<std::vector<int>>>(dims);
if (x.size() != 2) {
throw std::invalid_argument(
"[tensordot] dims must be a list of two lists.");
}
return tensordot(a, b, {x[0], x[1]}, s);
}
},
"a"_a,
"b"_a,
py::pos_only(),
"dims"_a = 2,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
Compute the tensor dot product along the specified axes.
Args:
a (array): Input array
b (array): Input array
dims (int or list(list(int)), optional): The number of dimensions to
sum over. If an integer is provided, then sum over the last
``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
``b``. If a list of lists is provided, then sum over the
corresponding dimensions of ``a`` and ``b``. (default: 2)
Returns:
result (array): The tensor dot product.
)pbdoc");
}