nice tensordot for mlx c (#782)

This commit is contained in:
Awni Hannun
2024-03-04 09:51:02 -08:00
committed by GitHub
parent 6a665ea6ed
commit 5121f028d9
4 changed files with 32 additions and 33 deletions

View File

@@ -3555,17 +3555,17 @@ void init_ops(py::module_& m) {
"tensordot",
[](const array& a,
const array& b,
const std::variant<int, std::vector<std::vector<int>>>& dims,
const std::variant<int, std::vector<std::vector<int>>>& axes,
StreamOrDevice s) {
if (auto pv = std::get_if<int>(&dims); pv) {
if (auto pv = std::get_if<int>(&axes); pv) {
return tensordot(a, b, *pv, s);
} else {
auto x = std::get<std::vector<std::vector<int>>>(dims);
auto& x = std::get<std::vector<std::vector<int>>>(axes);
if (x.size() != 2) {
throw std::invalid_argument(
"[tensordot] dims must be a list of two lists.");
"[tensordot] axes must be a list of two lists.");
}
return tensordot(a, b, {x[0], x[1]}, s);
return tensordot(a, b, x[0], x[1], s);
}
},
"a"_a,