mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
fix temporary bug (#752)
This commit is contained in:
parent
420ff2f331
commit
f5f18b704f
@ -31,7 +31,8 @@ void explicit_gemm_conv_1D_gpu(
|
|||||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||||
|
|
||||||
// Fill with zeros
|
// Fill with zeros
|
||||||
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
auto zero = array(0, in.dtype());
|
||||||
|
copy_gpu(zero, in_padded, CopyType::Scalar, s);
|
||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
|
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
|
||||||
@ -68,7 +69,7 @@ void explicit_gemm_conv_1D_gpu(
|
|||||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||||
|
|
||||||
// Perform gemm
|
// Perform gemm
|
||||||
std::vector<array> copies = {in_padded, in_strided};
|
std::vector<array> copies = {zero, in_padded, in_strided};
|
||||||
return steel_matmul(
|
return steel_matmul(
|
||||||
s,
|
s,
|
||||||
d,
|
d,
|
||||||
@ -213,6 +214,7 @@ void explicit_gemm_conv_2D_gpu(
|
|||||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||||
|
|
||||||
// Fill with zeros
|
// Fill with zeros
|
||||||
|
auto zero = array(0, in.dtype());
|
||||||
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
@ -259,7 +261,7 @@ void explicit_gemm_conv_2D_gpu(
|
|||||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||||
|
|
||||||
// Perform gemm
|
// Perform gemm
|
||||||
std::vector<array> copies = {in_padded, in_strided};
|
std::vector<array> copies = {zero, in_padded, in_strided};
|
||||||
return steel_matmul(
|
return steel_matmul(
|
||||||
s,
|
s,
|
||||||
d,
|
d,
|
||||||
|
Loading…
Reference in New Issue
Block a user