mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding * fix inadvertant cast * add tol
This commit is contained in:
@@ -1395,11 +1395,12 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"take",
|
||||
[](const array& a,
|
||||
const std::variant<int, array>& indices,
|
||||
const std::variant<nb::int_, array>& indices,
|
||||
const std::optional<int>& axis,
|
||||
StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&indices); pv) {
|
||||
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
|
||||
if (auto pv = std::get_if<nb::int_>(&indices); pv) {
|
||||
auto idx = nb::cast<int>(*pv);
|
||||
return axis ? take(a, idx, axis.value(), s) : take(a, idx, s);
|
||||
} else {
|
||||
auto indices_ = std::get<array>(indices);
|
||||
return axis ? take(a, indices_, axis.value(), s)
|
||||
|
Reference in New Issue
Block a user