Fix addmm with empty matrices and beta != 1.0 (#2715)

This commit is contained in:
Harsh Sutaria
2025-11-03 17:16:15 -05:00
committed by GitHub
parent 1ff2b713b6
commit 50fa315d18
3 changed files with 105 additions and 16 deletions

View File

@@ -2,6 +2,8 @@
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
@@ -135,15 +137,58 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
// Handle empty matrix case (K=0)
if (inputs[0].shape(-1) == 0) {
auto& c = inputs[2];
if (beta_ == 1.0f) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
} else {
array beta_scalar = array(beta_, c.dtype());
auto bopt = get_binary_op_type(c, beta_scalar);
set_binary_op_output_data(c, beta_scalar, out, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(c);
encoder.set_input_array(beta_scalar);
encoder.set_output_array(out);
encoder.dispatch([c = array::unsafe_weak_copy(c),
beta_scalar = array::unsafe_weak_copy(beta_scalar),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, detail::Multiply>(c, beta_scalar, out, bopt);
break;
case float32:
binary_op<float, detail::Multiply>(c, beta_scalar, out, bopt);
break;
case float64:
binary_op<double, detail::Multiply>(c, beta_scalar, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, detail::Multiply>(c, beta_scalar, out, bopt);
break;
case complex64:
binary_op<complex64_t, detail::Multiply>(c, beta_scalar, out, bopt);
break;
default:
throw std::runtime_error(
"[AddMM::eval_cpu] Unsupported dtype for beta scaling");
}
});
encoder.add_temporary(std::move(beta_scalar));
}
return;
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}