diff --git a/llms/mlx_lm/models/switch_layers.py b/llms/mlx_lm/models/switch_layers.py
index cad99ec0..00aa65d8 100644
--- a/llms/mlx_lm/models/switch_layers.py
+++ b/llms/mlx_lm/models/switch_layers.py
@@ -55,7 +55,7 @@ class QuantizedSwitchLinear(nn.Module):
         return self.weight.shape[0]
 
     def __call__(self, x, indices):
-        x = mx.block_sparse_qmm(
+        x = mx.gather_qmm(
             x,
             self["weight"],
             self["scales"],
@@ -98,7 +98,7 @@ class SwitchLinear(nn.Module):
         return self.weight.shape[0]
 
     def __call__(self, x, indices):
-        x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
+        x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
         if "bias" in self:
             x = x + mx.expand_dims(self["bias"][indices], -2)
         return x
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index 226ac053..e6cb70de 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.13.1
+mlx>=0.14
 numpy
 transformers>=4.39.3
 protobuf
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d665325e..82c00fca 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -16,7 +16,7 @@ import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from mlx.utils import tree_flatten
-from transformers import AutoTokenizer, PreTrainedTokenizer
+from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models.base import KVCache
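
Note (not part of the patch): a minimal sketch of the renamed expert-routing matmul, assuming mlx>=0.14 exposes mx.gather_mm with the rhs_indices keyword used in the diff; the expert count, dimensions, and values below are illustrative only.

import mlx.core as mx

# Hypothetical toy setup: 4 experts, each an (8 -> 16) projection stored as
# (num_experts, output_dims, input_dims), matching SwitchLinear's weight layout.
experts = mx.random.normal((4, 16, 8))
x = mx.random.normal((2, 1, 8))   # two tokens, each a (1, 8) row vector
indices = mx.array([1, 3])        # expert selected for each token

# gather_mm is the mlx 0.14 name for the former block_sparse_mm: each token is
# multiplied only by the expert weight picked out by rhs_indices.
y = mx.gather_mm(x, experts.swapaxes(-1, -2), rhs_indices=indices)
print(y.shape)  # (2, 1, 16)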