Allow offset to be an mx.array for mx.fast.rope (#1724)

* allow offset for rope

* comment
This commit is contained in:
Awni Hannun
2024-12-19 15:51:44 -08:00
committed by GitHub
parent c3628eea49
commit 0308e9af71
8 changed files with 97 additions and 52 deletions

View File

@@ -22,21 +22,25 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using Scalar = std::variant<int, double>;
using Scalar = std::variant<bool, int, double>;
mx::Dtype scalar_to_dtype(Scalar scalar) {
if (std::holds_alternative<int>(scalar)) {
mx::Dtype scalar_to_dtype(Scalar s) {
if (std::holds_alternative<int>(s)) {
return mx::int32;
} else {
} else if (std::holds_alternative<double>(s)) {
return mx::float32;
} else {
return mx::bool_;
}
}
double scalar_to_double(Scalar s) {
if (std::holds_alternative<double>(s)) {
return std::get<double>(s);
if (auto pv = std::get_if<int>(&s); pv) {
return static_cast<double>(*pv);
} else if (auto pv = std::get_if<double>(&s); pv) {
return *pv;
} else {
return static_cast<double>(std::get<int>(s));
return static_cast<double>(std::get<bool>(s));
}
}
@@ -1367,9 +1371,9 @@ void init_ops(nb::module_& m) {
dtype,
s);
},
"start"_a,
"stop"_a,
"step"_a = nb::none(),
"start"_a.noconvert(),
"stop"_a.noconvert(),
"step"_a.noconvert() = nb::none(),
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1413,8 +1417,8 @@ void init_ops(nb::module_& m) {
dtype,
s);
},
"stop"_a,
"step"_a = nb::none(),
"stop"_a.noconvert(),
"step"_a.noconvert() = nb::none(),
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),