Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Fix addmm with empty matrices and beta != 1.0 (#2715)
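For context, addmm computes out = alpha * (A @ B) + beta * C. When A has inner dimension K = 0, A @ B is an all-zero matrix, so the result should reduce to beta * C; before this commit, both backends took an empty-K fast path that copied C through unscaled whenever beta != 1.0. The sketch below is a minimal illustration of the intended behavior, assuming the public C++ API in "mlx/mlx.h" with addmm(c, a, b, alpha, beta); treat the exact signatures as assumptions rather than a definitive test.

#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // A is 2x0 and B is 0x3, so A @ B is a 2x3 matrix of zeros (K = 0).
  auto a = mx::zeros({2, 0});
  auto b = mx::zeros({0, 3});
  auto c = mx::full({2, 3}, 2.0f);

  // out = alpha * (A @ B) + beta * C should be 0.5 * C = 1.0 everywhere.
  auto out = mx::addmm(c, a, b, /* alpha */ 1.0f, /* beta */ 0.5f);

  // Expected: all ones. Before the fix, the empty-K path copied C through
  // without applying beta, so this printed 2.0 everywhere.
  std::cout << out << std::endl;
  return 0;
}

The diff below adds the missing beta scaling to the empty-K path in both the CPU and Metal backends.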
@@ -2,6 +2,8 @@
 
 #include <cstring>
 #include "mlx/array.h"
+#include "mlx/backend/cpu/binary.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/gemm.h"
@@ -135,15 +137,58 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     return;
   }

+  // Handle empty matrix case (K=0)
+  if (inputs[0].shape(-1) == 0) {
+    auto& c = inputs[2];
+    if (beta_ == 1.0f) {
+      CopyType ctype = c.data_size() == 1
+          ? CopyType::Scalar
+          : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+      copy_cpu(c, out, ctype, stream());
+    } else {
+      array beta_scalar = array(beta_, c.dtype());
+      auto bopt = get_binary_op_type(c, beta_scalar);
+      set_binary_op_output_data(c, beta_scalar, out, bopt);
+      auto& encoder = cpu::get_command_encoder(stream());
+      encoder.set_input_array(c);
+      encoder.set_input_array(beta_scalar);
+      encoder.set_output_array(out);
+      encoder.dispatch([c = array::unsafe_weak_copy(c),
+                        beta_scalar = array::unsafe_weak_copy(beta_scalar),
+                        out = array::unsafe_weak_copy(out),
+                        bopt]() mutable {
+        switch (out.dtype()) {
+          case float16:
+            binary_op<float16_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case float32:
+            binary_op<float, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case float64:
+            binary_op<double, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case bfloat16:
+            binary_op<bfloat16_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case complex64:
+            binary_op<complex64_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          default:
+            throw std::runtime_error(
+                "[AddMM::eval_cpu] Unsupported dtype for beta scaling");
+        }
+      });
+      encoder.add_temporary(std::move(beta_scalar));
+    }
+    return;
+  }
+
   // Fill output with C
   auto& c = inputs[2];
   CopyType ctype = c.data_size() == 1
       ? CopyType::Scalar
       : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
   copy_cpu(c, out, ctype, stream());
-  if (inputs[0].shape(-1) == 0) {
-    return;
-  }
   matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
 }
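On the CPU side, when beta != 1.0 the new branch materializes beta as a scalar array and multiplies it into C through the command encoder. Because binary_op is templated on the element type, the runtime dtype has to be switched over explicitly to pick a concrete instantiation, which is why the branch enumerates float16 through complex64 and rejects anything else. The standalone sketch below shows that dispatch pattern in isolation; the enum, buffer layout, and helper names are invented for illustration and are not MLX internals.

#include <cstddef>
#include <stdexcept>

enum class DType { Float32, Float64 };

// The typed kernel: the runtime dtype must be mapped onto a compile-time
// element type before the elementwise work can run.
template <typename T>
void scale_buffer(const void* in, void* out, std::size_t n, double beta) {
  auto src = static_cast<const T*>(in);
  auto dst = static_cast<T*>(out);
  for (std::size_t i = 0; i < n; ++i) {
    dst[i] = static_cast<T>(beta) * src[i];  // out = beta * C, elementwise
  }
}

// Mirrors the switch-on-dtype shape of the new K == 0 branch in
// AddMM::eval_cpu: pick a typed kernel, or reject unsupported dtypes.
void scale_by_beta(const void* c, void* out, std::size_t n, double beta, DType dt) {
  switch (dt) {
    case DType::Float32:
      scale_buffer<float>(c, out, n, beta);
      break;
    case DType::Float64:
      scale_buffer<double>(c, out, n, beta);
      break;
    default:
      throw std::runtime_error("Unsupported dtype for beta scaling");
  }
}

The Metal hunk that follows takes a shorter route: it reuses the existing "Multiply" binary kernel via binary_op_gpu and registers the temporary beta scalar with the device.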
@@ -8,6 +8,7 @@
 #include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/metal/binary.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -925,19 +926,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }

-  // Copy c into out and return
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  // Handle empty matrix case (K=0)
   if (inputs[0].shape(-1) == 0) {
-    copy_gpu(
-        inputs[2],
-        out,
-        inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
-        stream());
+    auto& c = inputs[2];
+    if (beta_ == 1.0f) {
+      copy_gpu(
+          c,
+          out,
+          c.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+          s);
+    } else {
+      array beta_scalar = array(beta_, c.dtype());
+      binary_op_gpu({c, beta_scalar}, out, "Multiply", s);
+      d.add_temporary(std::move(beta_scalar), s.index);
+    }
     return;
   }

   out.set_data(allocator::malloc(out.nbytes()));
-  auto& s = stream();
-  auto& d = metal::device(s.device);

   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
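A quick way to exercise both patched paths is to run the same empty-K addmm on the CPU and GPU devices and compare against beta * C. This is a hedged sketch, not a test from the repository: it assumes the public C++ API ("mlx/mlx.h"), the Device(Device::gpu) style of device selection, and a Metal-capable machine for the GPU line.

#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto a = mx::zeros({4, 0});
  auto b = mx::zeros({0, 5});
  auto c = mx::full({4, 5}, 3.0f);
  float beta = 0.25f;

  // Same empty-K addmm on both backends touched by the diff (device-passing
  // style is an assumption about the public API).
  auto cpu_out = mx::addmm(c, a, b, 1.0f, beta, mx::Device(mx::Device::cpu));
  auto gpu_out = mx::addmm(c, a, b, 1.0f, beta, mx::Device(mx::Device::gpu));

  // Both should now equal beta * C (0.75 everywhere).
  auto expected = mx::multiply(mx::array(beta), c);
  std::cout << "cpu matches beta * C: "
            << mx::allclose(cpu_out, expected).item<bool>() << "\n"
            << "gpu matches beta * C: "
            << mx::allclose(gpu_out, expected).item<bool>() << std::endl;
  return 0;
}

With the fix in place, both comparisons should print 1 (true); previously the empty-K fast path returned C itself, so they would fail for any beta other than 1.0.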