mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-05 16:26:03 +08:00
Rename block sparse to gather (#793)
* rename block sparse to gather * pin mlx version
This commit is contained in:
@@ -55,7 +55,7 @@ class QuantizedSwitchLinear(nn.Module):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.block_sparse_qmm(
|
||||
x = mx.gather_qmm(
|
||||
x,
|
||||
self["weight"],
|
||||
self["scales"],
|
||||
@@ -98,7 +98,7 @@ class SwitchLinear(nn.Module):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||
x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
Reference in New Issue
Block a user