mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Fix alpha not respected when using bias epilogue (#2578)
This commit is contained in:
@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const Strides& b_batch_strides,
|
||||
float alpha) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||
nullptr);
|
||||
nullptr,
|
||||
alpha);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user