mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-08 18:06:37 +08:00
Rename block sparse to gather (#793)
* rename block sparse to gather
* pin mlx version
This commit is contained in:
parent
69700d8431
commit
ca7ce60c91
@@ -55,7 +55,7 @@ class QuantizedSwitchLinear(nn.Module):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.block_sparse_qmm(
|
||||
x = mx.gather_qmm(
|
||||
x,
|
||||
self["weight"],
|
||||
self["scales"],
|
||||
@@ -98,7 +98,7 @@ class SwitchLinear(nn.Module):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||
x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
@@ -1,4 +1,4 @@
|
||||
mlx>=0.13.1
|
||||
mlx>=0.14
|
||||
numpy
|
||||
transformers>=4.39.3
|
||||
protobuf
|
||||
|
@@ -16,7 +16,7 @@ import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from mlx.utils import tree_flatten
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models.base import KVCache
|
||||
|
Loading…
Reference in New Issue
Block a user