mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Allow offset to be an mx.array for mx.fast.rope
(#1724)
* allow offset for rope * comment
This commit is contained in:
@@ -22,21 +22,25 @@ namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using Scalar = std::variant<int, double>;
|
||||
using Scalar = std::variant<bool, int, double>;
|
||||
|
||||
mx::Dtype scalar_to_dtype(Scalar scalar) {
|
||||
if (std::holds_alternative<int>(scalar)) {
|
||||
mx::Dtype scalar_to_dtype(Scalar s) {
|
||||
if (std::holds_alternative<int>(s)) {
|
||||
return mx::int32;
|
||||
} else {
|
||||
} else if (std::holds_alternative<double>(s)) {
|
||||
return mx::float32;
|
||||
} else {
|
||||
return mx::bool_;
|
||||
}
|
||||
}
|
||||
|
||||
double scalar_to_double(Scalar s) {
|
||||
if (std::holds_alternative<double>(s)) {
|
||||
return std::get<double>(s);
|
||||
if (auto pv = std::get_if<int>(&s); pv) {
|
||||
return static_cast<double>(*pv);
|
||||
} else if (auto pv = std::get_if<double>(&s); pv) {
|
||||
return *pv;
|
||||
} else {
|
||||
return static_cast<double>(std::get<int>(s));
|
||||
return static_cast<double>(std::get<bool>(s));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1367,9 +1371,9 @@ void init_ops(nb::module_& m) {
|
||||
dtype,
|
||||
s);
|
||||
},
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"step"_a = nb::none(),
|
||||
"start"_a.noconvert(),
|
||||
"stop"_a.noconvert(),
|
||||
"step"_a.noconvert() = nb::none(),
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -1413,8 +1417,8 @@ void init_ops(nb::module_& m) {
|
||||
dtype,
|
||||
s);
|
||||
},
|
||||
"stop"_a,
|
||||
"step"_a = nb::none(),
|
||||
"stop"_a.noconvert(),
|
||||
"step"_a.noconvert() = nb::none(),
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
|
Reference in New Issue
Block a user