Update GEMM (#424)

* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
This commit is contained in:
Jagrit Digani
2024-01-17 12:42:39 -08:00
committed by GitHub
parent 556cdf0e06
commit 78102a47ad
30 changed files with 2361 additions and 646 deletions

View File

@@ -63,9 +63,10 @@ class Linear(Module):
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
def __call__(self, x: mx.array) -> mx.array:
x = x @ self.weight.T
if "bias" in self:
x = x + self.bias
x = mx.addmm(self.bias, x, self.weight.T)
else:
x = x @ self.weight.T
return x