mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 01:08:10 +08:00
fix addmm cpu (#2699)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@@ -49,9 +48,15 @@ void matmul_bnns(
|
||||
size_t K = a_shape[ndim - 1];
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
if (beta != 1.0 && beta != 0.0) {
|
||||
// scale the output
|
||||
for (auto i = 0; i < batch_size * M * N; ++i) {
|
||||
out[i] *= beta;
|
||||
}
|
||||
beta = 1.0;
|
||||
}
|
||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta,
|
||||
|
||||
@@ -717,8 +717,8 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
c = mx.ones((32, 32)).astype(t)
|
||||
a = mx.random.uniform(shape=(32, 32)).astype(t)
|
||||
b = mx.random.uniform(shape=(32, 32)).astype(t)
|
||||
out = mx.addmm(c, a, b)
|
||||
expected = a @ b + c
|
||||
out = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
|
||||
expected = 0.5 * (a @ b) + 2.0 * c
|
||||
self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))
|
||||
|
||||
def test_addmm_grad(self):
|
||||
|
||||
Reference in New Issue
Block a user