From f5f18b704fb0a77f6bd56dbaeb687464dcb24bd5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 27 Feb 2024 17:44:39 -0800 Subject: [PATCH] fix temporary bug (#752) --- mlx/backend/metal/conv.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 4ade8da17..add976e6c 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -31,7 +31,8 @@ void explicit_gemm_conv_1D_gpu( array in_padded(padded_shape, in.dtype(), nullptr, {}); // Fill with zeros - copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s); + auto zero = array(0, in.dtype()); + copy_gpu(zero, in_padded, CopyType::Scalar, s); // Pick input slice from padded size_t data_offset = conv_params.pad[0] * in_padded.strides()[1]; @@ -68,7 +69,7 @@ void explicit_gemm_conv_1D_gpu( copy_gpu(in_strided_view, in_strided, CopyType::General, s); // Perform gemm - std::vector copies = {in_padded, in_strided}; + std::vector copies = {zero, in_padded, in_strided}; return steel_matmul( s, d, @@ -213,6 +214,7 @@ void explicit_gemm_conv_2D_gpu( array in_padded(padded_shape, in.dtype(), nullptr, {}); // Fill with zeros + auto zero = array(0, in.dtype()); copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s); // Pick input slice from padded @@ -259,7 +261,7 @@ void explicit_gemm_conv_2D_gpu( copy_gpu(in_strided_view, in_strided, CopyType::General, s); // Perform gemm - std::vector copies = {in_padded, in_strided}; + std::vector copies = {zero, in_padded, in_strided}; return steel_matmul( s, d,