[CUDA] Fix alpha not respected when using bias epilogue (#2578)

Author: Cheng
Date: 2025-09-10 09:08:01 +09:00
Committed by: GitHub
Parent: dde3682b69
Commit: 44cc5da4bc
6 changed files with 146 additions and 125 deletions
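
AddMM computes alpha * (a @ b) + beta * c. On the CUDA backend, the cuBLASLt bias-epilogue fast path called execute() without forwarding the scale, so any alpha other than 1.0f was silently dropped whenever the bias epilogue was used. The fix threads alpha from AddMM::eval_gpu() through gemm_and_bias(), CublasGemm::run(), and both run_batched() overloads. As a reference for the intended semantics, here is a minimal sketch in plain C++ (a hypothetical helper, not MLX code; beta is folded to 1, which is the case this fast path is assumed to cover):

#include <vector>

// addmm_reference is a hypothetical illustration of the expected result:
// alpha * (a @ b) + c for a row-major M x K by K x N product with a
// length-N bias broadcast over rows.
void addmm_reference(
    const std::vector<float>& a, // M x K
    const std::vector<float>& b, // K x N
    const std::vector<float>& c, // length N
    std::vector<float>& out,     // M x N
    int M, int K, int N,
    float alpha) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += a[m * K + k] * b[k * N + n];
      }
      // The bug: the fast path effectively computed acc + c[n] here.
      out[m * N + n] = alpha * acc + c[n];
    }
  }
}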


@@ -248,11 +248,19 @@ void CublasGemm::run(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   if (batch_count / batch_shape.back() > 1) {
     run_batched(
-        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+        encoder,
+        out,
+        a,
+        b,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        alpha);
     return;
   }
 
@@ -260,7 +268,13 @@ void CublasGemm::run(
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
-  execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
+  execute(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      nullptr,
+      alpha);
 }
 
 void CublasGemm::run(

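For context, execute() (not shown in this diff) ultimately hands the scale to cublasLtMatmul, which takes alpha and beta by pointer. Below is a hedged sketch of that receiving side, not the repository's actual execute(): the handle and layout descriptors are assumed to be created elsewhere, with CUBLASLT_EPILOGUE_BIAS and the bias pointer already set on the matmul descriptor, in which case cuBLASLt computes D = alpha * (A @ B) + bias.

#include <cublasLt.h>

// Sketch only: forwarding the real alpha (rather than a hardcoded 1.0f)
// is what makes the scale reach the product when the bias epilogue is on.
void execute_sketch(
    cublasLtHandle_t handle,
    cublasLtMatmulDesc_t desc, // epilogue + bias pointer set elsewhere
    cublasLtMatrixLayout_t a_desc,
    cublasLtMatrixLayout_t b_desc,
    cublasLtMatrixLayout_t out_desc,
    const void* a,
    const void* b,
    void* out,
    float alpha, // the value threaded through by this commit
    cudaStream_t stream) {
  float beta = 0.0f; // the bias is added by the epilogue, not via beta * C
  cublasStatus_t status = cublasLtMatmul(
      handle,
      desc,
      &alpha,
      a,
      a_desc,
      b,
      b_desc,
      &beta,
      out,
      out_desc, // C (contributes nothing since beta == 0)
      out,
      out_desc, // D
      /*algo=*/nullptr,
      /*workspace=*/nullptr,
      /*workspaceSizeInBytes=*/0,
      stream);
  (void)status; // error handling elided in this sketch
}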

@@ -64,7 +64,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha = 1.0f);
 
   void run(
       cu::CommandEncoder& encoder,
@@ -87,7 +88,8 @@ class CublasGemm {
       const array& b,
       const Shape& batch_shape,
       const Strides& a_batch_strides,
-      const Strides& b_batch_strides);
+      const Strides& b_batch_strides,
+      float alpha);
 
   void run_batched(
       cu::CommandEncoder& encoder,


@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
         a.data<int8_t>() + a.itemsize() * a_it.loc,
         b.data<int8_t>() + b.itemsize() * b_it.loc,
-        nullptr);
+        nullptr,
+        alpha);
     a_it.step();
     b_it.step();
   }


@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
     const array& b,
     const Shape& batch_shape,
     const Strides& a_batch_strides,
-    const Strides& b_batch_strides) {
+    const Strides& b_batch_strides,
+    float alpha) {
   int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
       reinterpret_cast<void*>(out_pointers),
       reinterpret_cast<void*>(a_pointers),
       reinterpret_cast<void*>(b_pointers),
-      nullptr);
+      nullptr,
+      alpha);
 }
 
 void CublasGemm::run_batched(


@@ -41,7 +41,8 @@ void gemm_and_bias(
     array& out,
     const array& a,
     const array& b,
-    void* bias = nullptr) {
+    void* bias = nullptr,
+    float alpha = 1.0f) {
   // Check and collapse batch dimensions
   auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
 
@@ -94,7 +95,8 @@ void gemm_and_bias(
   if (bias) {
     gemm.set_bias(bias);
   }
-  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm.run(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
 }
 
 } // namespace
@@ -169,7 +171,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       out,
       a,
       b,
-      c.data<void>());
+      c.data<void>(),
+      alpha_);
   return;
 }
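
A quick way to exercise the fixed path from the C++ API. This is a hedged sketch: it assumes mlx::core::addmm with the signature addmm(c, a, b, alpha, beta) and that a length-N bias with beta == 1 is routed through the bias epilogue.

#include <cassert>

#include "mlx/mlx.h"

// Hedged regression sketch: before the fix, this path returned a @ b + c
// regardless of alpha; with the fix, the product is scaled first.
int main() {
  namespace mx = mlx::core;
  auto a = mx::ones({4, 8}, mx::float32);
  auto b = mx::ones({8, 16}, mx::float32);
  auto c = mx::ones({16}, mx::float32); // row bias
  auto out = mx::addmm(c, a, b, /*alpha=*/2.0f, /*beta=*/1.0f);
  // Expected: 2 * (a @ b) + c = 2 * 8 + 1 = 17 in every position.
  assert(mx::all(mx::equal(out, mx::array(17.0f))).item<bool>());
  return 0;
}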