mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
fix temporary bug (#752)
This commit is contained in:
parent
420ff2f331
commit
f5f18b704f
@ -31,7 +31,8 @@ void explicit_gemm_conv_1D_gpu(
|
||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
||||
auto zero = array(0, in.dtype());
|
||||
copy_gpu(zero, in_padded, CopyType::Scalar, s);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
|
||||
@ -68,7 +69,7 @@ void explicit_gemm_conv_1D_gpu(
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
std::vector<array> copies = {zero, in_padded, in_strided};
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
@ -213,6 +214,7 @@ void explicit_gemm_conv_2D_gpu(
|
||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
auto zero = array(0, in.dtype());
|
||||
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
||||
|
||||
// Pick input slice from padded
|
||||
@ -259,7 +261,7 @@ void explicit_gemm_conv_2D_gpu(
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
std::vector<array> copies = {zero, in_padded, in_strided};
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
|
Loading…
Reference in New Issue
Block a user