mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
nice tensordot for mlx c (#782)
This commit is contained in:
@@ -3555,17 +3555,17 @@ void init_ops(py::module_& m) {
|
||||
"tensordot",
|
||||
[](const array& a,
|
||||
const array& b,
|
||||
const std::variant<int, std::vector<std::vector<int>>>& dims,
|
||||
const std::variant<int, std::vector<std::vector<int>>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&dims); pv) {
|
||||
if (auto pv = std::get_if<int>(&axes); pv) {
|
||||
return tensordot(a, b, *pv, s);
|
||||
} else {
|
||||
auto x = std::get<std::vector<std::vector<int>>>(dims);
|
||||
auto& x = std::get<std::vector<std::vector<int>>>(axes);
|
||||
if (x.size() != 2) {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] dims must be a list of two lists.");
|
||||
"[tensordot] axes must be a list of two lists.");
|
||||
}
|
||||
return tensordot(a, b, {x[0], x[1]}, s);
|
||||
return tensordot(a, b, x[0], x[1], s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
|
Reference in New Issue
Block a user