mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
@@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"full",
|
||||
[](const std::variant<int, std::vector<int>>& shape,
|
||||
[](const std::variant<int, mx::Shape>& shape,
|
||||
const ScalarOrArray& vals,
|
||||
std::optional<mx::Dtype> dtype,
|
||||
mx::StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&shape); pv) {
|
||||
return mx::full({*pv}, to_array(vals, dtype), s);
|
||||
} else {
|
||||
return mx::full(
|
||||
std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
|
||||
return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);
|
||||
}
|
||||
},
|
||||
"shape"_a,
|
||||
@@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"zeros",
|
||||
[](const std::variant<int, std::vector<int>>& shape,
|
||||
[](const std::variant<int, mx::Shape>& shape,
|
||||
std::optional<mx::Dtype> dtype,
|
||||
mx::StreamOrDevice s) {
|
||||
auto t = dtype.value_or(mx::float32);
|
||||
if (auto pv = std::get_if<int>(&shape); pv) {
|
||||
return mx::zeros({*pv}, t, s);
|
||||
} else {
|
||||
return mx::zeros(std::get<std::vector<int>>(shape), t, s);
|
||||
return mx::zeros(std::get<mx::Shape>(shape), t, s);
|
||||
}
|
||||
},
|
||||
"shape"_a,
|
||||
@@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ones",
|
||||
[](const std::variant<int, std::vector<int>>& shape,
|
||||
[](const std::variant<int, mx::Shape>& shape,
|
||||
std::optional<mx::Dtype> dtype,
|
||||
mx::StreamOrDevice s) {
|
||||
auto t = dtype.value_or(mx::float32);
|
||||
if (auto pv = std::get_if<int>(&shape); pv) {
|
||||
return mx::ones({*pv}, t, s);
|
||||
} else {
|
||||
return mx::ones(std::get<std::vector<int>>(shape), t, s);
|
||||
return mx::ones(std::get<mx::Shape>(shape), t, s);
|
||||
}
|
||||
},
|
||||
"shape"_a,
|
||||
@@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"split",
|
||||
[](const mx::array& a,
|
||||
const std::variant<int, std::vector<int>>& indices_or_sections,
|
||||
const std::variant<int, mx::Shape>& indices_or_sections,
|
||||
int axis,
|
||||
mx::StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
|
||||
return mx::split(a, *pv, axis, s);
|
||||
} else {
|
||||
return mx::split(
|
||||
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
|
||||
a, std::get<mx::Shape>(indices_or_sections), axis, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
@@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"broadcast_to",
|
||||
[](const ScalarOrArray& a,
|
||||
const std::vector<int>& shape,
|
||||
mx::StreamOrDevice s) {
|
||||
[](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
|
||||
return mx::broadcast_to(to_array(a), shape, s);
|
||||
},
|
||||
nb::arg(),
|
||||
@@ -4895,23 +4892,15 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"roll",
|
||||
[](const mx::array& a,
|
||||
const IntOrVec& shift,
|
||||
const std::variant<int, mx::Shape>& shift,
|
||||
const IntOrVec& axis,
|
||||
mx::StreamOrDevice s) {
|
||||
return std::visit(
|
||||
[&](auto sh, auto ax) -> mx::array {
|
||||
using T = decltype(ax);
|
||||
using V = decltype(sh);
|
||||
|
||||
if constexpr (std::is_same_v<V, std::monostate>) {
|
||||
throw std::invalid_argument(
|
||||
"[roll] Expected two arguments but only one was given.");
|
||||
if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
|
||||
return mx::roll(a, sh, s);
|
||||
} else {
|
||||
if constexpr (std::is_same_v<T, std::monostate>) {
|
||||
return mx::roll(a, sh, s);
|
||||
} else {
|
||||
return mx::roll(a, sh, ax, s);
|
||||
}
|
||||
return mx::roll(a, sh, ax, s);
|
||||
}
|
||||
},
|
||||
shift,
|
||||
|
Reference in New Issue
Block a user