diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 20dbe181a..9a76188fa 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3215,5 +3215,21 @@ void init_ops(py::module_& m) { "dims"_a = 2, py::kw_only(), "stream"_a = none, - ""); + R"pbdoc( + tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array + + Compute the tensor dot product along the specified axes. + + Args: + a (array): Input array + b (array): Input array + dims (int or list(list(int)), optional): The number of dimensions to + sum over. If an integer is provided, then sum over the last + ``dims`` dimensions of ``a`` and the first ``dims`` dimensions of + ``b``. If a list of lists is provided, then sum over the + corresponding dimensions of ``a`` and ``b``. (default: 2) + + Returns: + result (array): The tensor dot product. + )pbdoc"); }