Rename block sparse to gather (#793)

* rename block sparse to gather

* pin mlx version
This commit is contained in:
Awni Hannun
2024-05-23 19:47:35 -07:00
committed by GitHub
parent 69700d8431
commit ca7ce60c91
3 changed files with 4 additions and 4 deletions

View File

@@ -55,7 +55,7 @@ class QuantizedSwitchLinear(nn.Module):
return self.weight.shape[0]
def __call__(self, x, indices):
x = mx.block_sparse_qmm(
x = mx.gather_qmm(
x,
self["weight"],
self["scales"],
@@ -98,7 +98,7 @@ class SwitchLinear(nn.Module):
return self.weight.shape[0]
def __call__(self, x, indices):
x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
if "bias" in self:
x = x + mx.expand_dims(self["bias"][indices], -2)
return x