Rename block sparse to gather (#793)

* rename block sparse to gather

* pin mlx version
This commit is contained in:
Awni Hannun 2024-05-23 19:47:35 -07:00 committed by GitHub
parent 69700d8431
commit ca7ce60c91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 4 additions and 4 deletions

View File

@ -55,7 +55,7 @@ class QuantizedSwitchLinear(nn.Module):
return self.weight.shape[0] return self.weight.shape[0]
def __call__(self, x, indices): def __call__(self, x, indices):
x = mx.block_sparse_qmm( x = mx.gather_qmm(
x, x,
self["weight"], self["weight"],
self["scales"], self["scales"],
@ -98,7 +98,7 @@ class SwitchLinear(nn.Module):
return self.weight.shape[0] return self.weight.shape[0]
def __call__(self, x, indices): def __call__(self, x, indices):
x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices) x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
if "bias" in self: if "bias" in self:
x = x + mx.expand_dims(self["bias"][indices], -2) x = x + mx.expand_dims(self["bias"][indices], -2)
return x return x

View File

@ -1,4 +1,4 @@
mlx>=0.13.1 mlx>=0.14
numpy numpy
transformers>=4.39.3 transformers>=4.39.3
protobuf protobuf

View File

@ -16,7 +16,7 @@ import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError from huggingface_hub.utils._errors import RepositoryNotFoundError
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
from transformers import AutoTokenizer, PreTrainedTokenizer from transformers import PreTrainedTokenizer
# Local imports # Local imports
from .models.base import KVCache from .models.base import KVCache