More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"full",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
const ScalarOrArray& vals,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::full({*pv}, to_array(vals, dtype), s);
} else {
return mx::full(
std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);
}
},
"shape"_a,
@@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"zeros",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
auto t = dtype.value_or(mx::float32);
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::zeros({*pv}, t, s);
} else {
return mx::zeros(std::get<std::vector<int>>(shape), t, s);
return mx::zeros(std::get<mx::Shape>(shape), t, s);
}
},
"shape"_a,
@@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"ones",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
auto t = dtype.value_or(mx::float32);
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::ones({*pv}, t, s);
} else {
return mx::ones(std::get<std::vector<int>>(shape), t, s);
return mx::ones(std::get<mx::Shape>(shape), t, s);
}
},
"shape"_a,
@@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
m.def(
"split",
[](const mx::array& a,
const std::variant<int, std::vector<int>>& indices_or_sections,
const std::variant<int, mx::Shape>& indices_or_sections,
int axis,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
return mx::split(a, *pv, axis, s);
} else {
return mx::split(
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
a, std::get<mx::Shape>(indices_or_sections), axis, s);
}
},
nb::arg(),
@@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"broadcast_to",
[](const ScalarOrArray& a,
const std::vector<int>& shape,
mx::StreamOrDevice s) {
[](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
return mx::broadcast_to(to_array(a), shape, s);
},
nb::arg(),
@@ -4895,23 +4892,15 @@ void init_ops(nb::module_& m) {
m.def(
"roll",
[](const mx::array& a,
const IntOrVec& shift,
const std::variant<int, mx::Shape>& shift,
const IntOrVec& axis,
mx::StreamOrDevice s) {
return std::visit(
[&](auto sh, auto ax) -> mx::array {
using T = decltype(ax);
using V = decltype(sh);
if constexpr (std::is_same_v<V, std::monostate>) {
throw std::invalid_argument(
"[roll] Expected two arguments but only one was given.");
if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
return mx::roll(a, sh, s);
} else {
if constexpr (std::is_same_v<T, std::monostate>) {
return mx::roll(a, sh, s);
} else {
return mx::roll(a, sh, ax, s);
}
return mx::roll(a, sh, ax, s);
}
},
shift,