No reshapes in quantized embedding (#1682)

* no reshapes in quantized embedding

* fix inadvertent cast

* add tol
This commit is contained in:
Awni Hannun
2024-12-09 18:57:38 -08:00
committed by GitHub
parent 87d7a2520e
commit 29a620cab2
6 changed files with 26 additions and 12 deletions

View File

@@ -1395,11 +1395,12 @@ void init_ops(nb::module_& m) {
m.def(
"take",
[](const array& a,
const std::variant<int, array>& indices,
const std::variant<nb::int_, array>& indices,
const std::optional<int>& axis,
StreamOrDevice s) {
if (auto pv = std::get_if<int>(&indices); pv) {
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
if (auto pv = std::get_if<nb::int_>(&indices); pv) {
auto idx = nb::cast<int>(*pv);
return axis ? take(a, idx, axis.value(), s) : take(a, idx, s);
} else {
auto indices_ = std::get<array>(indices);
return axis ? take(a, indices_, axis.value(), s)

View File

@@ -459,13 +459,13 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"permuation",
[](const std::variant<int, array>& x,
[](const std::variant<nb::int_, array>& x,
int axis,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (auto pv = std::get_if<int>(&x); pv) {
return permutation(*pv, key, s);
if (auto pv = std::get_if<nb::int_>(&x); pv) {
return permutation(nb::cast<int>(*pv), key, s);
} else {
return permutation(std::get<array>(x), axis, key, s);
}