mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
remove monostate
This commit is contained in:
parent
bd1a11bef7
commit
3dc638bffb
@ -3198,12 +3198,9 @@ void init_ops(py::module_& m) {
|
|||||||
"tensordot",
|
"tensordot",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const std::variant<std::monostate, int, std::vector<std::vector<int>>>&
|
const std::variant<int, std::vector<std::vector<int>>>& dims,
|
||||||
dims,
|
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (std::holds_alternative<std::monostate>(dims)) {
|
if (auto pv = std::get_if<int>(&dims); pv) {
|
||||||
return tensordot(a, b, 2, s);
|
|
||||||
} else if (auto pv = std::get_if<int>(&dims); pv) {
|
|
||||||
return tensordot(a, b, *pv, s);
|
return tensordot(a, b, *pv, s);
|
||||||
} else {
|
} else {
|
||||||
return tensordot(
|
return tensordot(
|
||||||
|
Loading…
Reference in New Issue
Block a user