fix temporary bug (#752)

This commit is contained in:
Awni Hannun 2024-02-27 17:44:39 -08:00 committed by GitHub
parent 420ff2f331
commit f5f18b704f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -31,7 +31,8 @@ void explicit_gemm_conv_1D_gpu(
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
auto zero = array(0, in.dtype());
copy_gpu(zero, in_padded, CopyType::Scalar, s);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
@ -68,7 +69,7 @@ void explicit_gemm_conv_1D_gpu(
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
std::vector<array> copies = {zero, in_padded, in_strided};
return steel_matmul(
s,
d,
@ -213,6 +214,7 @@ void explicit_gemm_conv_2D_gpu(
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Fill with zeros
auto zero = array(0, in.dtype());
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Pick input slice from padded
@ -259,7 +261,7 @@ void explicit_gemm_conv_2D_gpu(
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
std::vector<array> copies = {zero, in_padded, in_strided};
return steel_matmul(
s,
d,