mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix addmm cpu (#2699)
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <Accelerate/Accelerate.h>
|
#include <Accelerate/Accelerate.h>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
@@ -49,9 +48,15 @@ void matmul_bnns(
|
|||||||
size_t K = a_shape[ndim - 1];
|
size_t K = a_shape[ndim - 1];
|
||||||
|
|
||||||
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
|
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
|
||||||
|
|
||||||
#pragma GCC diagnostic push
|
#pragma GCC diagnostic push
|
||||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||||
|
if (beta != 1.0 && beta != 0.0) {
|
||||||
|
// scale the output
|
||||||
|
for (auto i = 0; i < batch_size * M * N; ++i) {
|
||||||
|
out[i] *= beta;
|
||||||
|
}
|
||||||
|
beta = 1.0;
|
||||||
|
}
|
||||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||||
/* float alpha = */ alpha,
|
/* float alpha = */ alpha,
|
||||||
/* float beta = */ beta,
|
/* float beta = */ beta,
|
||||||
|
|||||||
@@ -717,8 +717,8 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
c = mx.ones((32, 32)).astype(t)
|
c = mx.ones((32, 32)).astype(t)
|
||||||
a = mx.random.uniform(shape=(32, 32)).astype(t)
|
a = mx.random.uniform(shape=(32, 32)).astype(t)
|
||||||
b = mx.random.uniform(shape=(32, 32)).astype(t)
|
b = mx.random.uniform(shape=(32, 32)).astype(t)
|
||||||
out = mx.addmm(c, a, b)
|
out = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
|
||||||
expected = a @ b + c
|
expected = 0.5 * (a @ b) + 2.0 * c
|
||||||
self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))
|
self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))
|
||||||
|
|
||||||
def test_addmm_grad(self):
|
def test_addmm_grad(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user