mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding * fix inadvertant cast * add tol
This commit is contained in:
@@ -459,13 +459,13 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"permuation",
|
||||
[](const std::variant<int, array>& x,
|
||||
[](const std::variant<nb::int_, array>& x,
|
||||
int axis,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (auto pv = std::get_if<int>(&x); pv) {
|
||||
return permutation(*pv, key, s);
|
||||
if (auto pv = std::get_if<nb::int_>(&x); pv) {
|
||||
return permutation(nb::cast<int>(*pv), key, s);
|
||||
} else {
|
||||
return permutation(std::get<array>(x), axis, key, s);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user