fix addmm cpu (#2699)

This commit is contained in:
Awni Hannun
2025-10-27 11:33:32 -07:00
committed by GitHub
parent 895217f25b
commit c4767d110f
2 changed files with 9 additions and 4 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
@@ -49,9 +48,15 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
}
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,

View File

@@ -717,8 +717,8 @@ class TestBlas(mlx_tests.MLXTestCase):
c = mx.ones((32, 32)).astype(t)
a = mx.random.uniform(shape=(32, 32)).astype(t)
b = mx.random.uniform(shape=(32, 32)).astype(t)
out = mx.addmm(c, a, b)
expected = a @ b + c
out = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
expected = 0.5 * (a @ b) + 2.0 * c
self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))
def test_addmm_grad(self):