mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-08 18:06:37 +08:00
Rename block sparse to gather (#793)
* rename block sparse to gather
* pin mlx version
This commit is contained in:
parent
69700d8431
commit
ca7ce60c91
@ -55,7 +55,7 @@ class QuantizedSwitchLinear(nn.Module):
|
|||||||
return self.weight.shape[0]
|
return self.weight.shape[0]
|
||||||
|
|
||||||
def __call__(self, x, indices):
|
def __call__(self, x, indices):
|
||||||
x = mx.block_sparse_qmm(
|
x = mx.gather_qmm(
|
||||||
x,
|
x,
|
||||||
self["weight"],
|
self["weight"],
|
||||||
self["scales"],
|
self["scales"],
|
||||||
@ -98,7 +98,7 @@ class SwitchLinear(nn.Module):
|
|||||||
return self.weight.shape[0]
|
return self.weight.shape[0]
|
||||||
|
|
||||||
def __call__(self, x, indices):
|
def __call__(self, x, indices):
|
||||||
x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||||
return x
|
return x
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mlx>=0.13.1
|
mlx>=0.14
|
||||||
numpy
|
numpy
|
||||||
transformers>=4.39.3
|
transformers>=4.39.3
|
||||||
protobuf
|
protobuf
|
||||||
|
@ -16,7 +16,7 @@ import mlx.nn as nn
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models.base import KVCache
|
from .models.base import KVCache
|
||||||
|
Loading…
Reference in New Issue
Block a user