mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/` * Update gemm elements for better performance * Add split-K specialization for gemm * Add `addmm` primitive, op and bindings for fused matmul and bias addition * Update tests and benchmarks as needed
This commit is contained in:
@@ -63,9 +63,10 @@ class Linear(Module):
|
||||
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = x @ self.weight.T
|
||||
if "bias" in self:
|
||||
x = x + self.bias
|
||||
x = mx.addmm(self.bias, x, self.weight.T)
|
||||
else:
|
||||
x = x @ self.weight.T
|
||||
return x
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user