mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Fix alpha not respected when using bias epilogue (#2578)
This commit is contained in:
@@ -248,11 +248,19 @@ void CublasGemm::run(
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const Strides& b_batch_strides,
|
||||
float alpha) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
encoder,
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
alpha);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -260,7 +268,13 @@ void CublasGemm::run(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
||||
execute(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
nullptr,
|
||||
alpha);
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
|
||||
Reference in New Issue
Block a user