diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index 6fb624d54..cbbbb5c3b 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -15,6 +15,7 @@ Layers BatchNorm Conv1d Conv2d + Conv3d Dropout Dropout2d Dropout3d diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp index f3162c056..0b19e60e3 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/common/conv.cpp @@ -310,6 +310,296 @@ void slow_conv_2D( } // n } +template +void slow_conv_3D( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { + const T* st_wt_ptr = wt.data(); + const T* st_in_ptr = in.data(); + T* st_out_ptr = out.data(); + + const int N = in.shape(0); // Batch size, should be the same as out.shape(0) + const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim + const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim + const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim + const int oD = out.shape(1); // Output spatial dim + const int oH = out.shape(2); // Output spatial dim + const int oW = out.shape(3); // Output spatial dim + const int O = wt.shape(0); // Out channels + const int C = wt.shape(4); // In channels + const int wD = wt.shape(1); // Weight spatial dim + const int wH = wt.shape(2); // Weight spatial dim + const int wW = wt.shape(3); // Weight spatial dim + + const size_t in_stride_N = in.strides()[0]; + const size_t in_stride_D = in.strides()[1]; + const size_t in_stride_H = in.strides()[2]; + const size_t in_stride_W = in.strides()[3]; + const size_t in_stride_C = in.strides()[4]; + + const size_t wt_stride_O = wt.strides()[0]; + const size_t wt_stride_D = wt.strides()[1]; + const size_t wt_stride_H = wt.strides()[2]; + const size_t wt_stride_W = wt.strides()[3]; + const size_t wt_stride_C = wt.strides()[4]; + + const size_t out_stride_N = out.strides()[0]; + const size_t out_stride_D = out.strides()[1]; + const size_t out_stride_H = out.strides()[2]; + const size_t out_stride_W = out.strides()[3]; + const size_t out_stride_O = out.strides()[4]; + + bool is_idil_one = + in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1; + + auto pt_conv_no_checks = [&](const T* in_ptr, + const T* wt_ptr, + T* out_ptr, + int od, + int oh, + int ow) { + out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; + int id_base = od * wt_strides[0] - padding[0]; + int ih_base = oh * wt_strides[1] - padding[1]; + int iw_base = ow * wt_strides[2] - padding[2]; + + for (int o = 0; o < O; ++o) { + float r = 0.; + + for (int wd = 0; wd < wD; ++wd) { + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wd_flip = flip ? wD - wd - 1 : wd; + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? 
wW - ww - 1 : ww; + int id = id_base + wd_flip * wt_dilation[0]; + int ih = ih_base + wh_flip * wt_dilation[1]; + int iw = iw_base + ww_flip * wt_dilation[2]; + + const T* wt_ptr_pt = + wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = + in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W; + + for (int c = 0; c < C; ++c) { + r += static_cast(in_ptr_pt[0]) * + static_cast(wt_ptr_pt[0]); + in_ptr_pt += in_stride_C; + wt_ptr_pt += wt_stride_C; + } // c + + } // ww + } // wh + } // wd + + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + }; + + int jump_d = flip ? -wt_dilation[0] : wt_dilation[0]; + int jump_h = flip ? -wt_dilation[1] : wt_dilation[1]; + int jump_w = flip ? -wt_dilation[2] : wt_dilation[2]; + + int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0); + int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0); + int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0); + + int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2]; + + int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2]; + + std::vector base_d(f_out_jump_d); + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); + + for (int i = 0; i < f_out_jump_d; ++i) { + int id_loop = i * wt_strides[0] - padding[0] + init_d; + + int wd_base = 0; + while (wd_base < wD && id_loop % in_dilation[0] != 0) { + wd_base++; + id_loop += jump_d; + } + + base_d[i] = wd_base; + } + + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[1] - padding[1] + init_h; + + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[1] != 0) { + wh_base++; + ih_loop += jump_h; + } + + base_h[i] = wh_base; + } + + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[2] - padding[2] + init_w; + + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[2] != 0) { + ww_base++; + iw_loop += jump_w; + } + + base_w[j] = ww_base; + } + + auto pt_conv_all_checks = [&](const T* in_ptr, + const T* wt_ptr, + T* out_ptr, + int od, + int oh, + int ow) { + out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; + + int id_base = od * wt_strides[0] - padding[0]; + int ih_base = oh * wt_strides[1] - padding[1]; + int iw_base = ow * wt_strides[2] - padding[2]; + + int wd_base = base_d[od % f_out_jump_d]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; + + for (int o = 0; o < O; ++o) { + float r = 0.; + + for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) { + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wd_flip = flip ? wD - wd - 1 : wd; + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int id = id_base + wd_flip * wt_dilation[0]; + int ih = ih_base + wh_flip * wt_dilation[1]; + int iw = iw_base + ww_flip * wt_dilation[2]; + + if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 && + iw < iW) { + const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + + wh * wt_stride_H + ww * wt_stride_W; + + int id_dil = !is_idil_one ? (id / in_dilation[0]) : id; + int ih_dil = !is_idil_one ? 
(ih / in_dilation[1]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw; + + const T* in_ptr_pt = in_ptr + id_dil * in_stride_D + + ih_dil * in_stride_H + iw_dil * in_stride_W; + + for (int c = 0; c < C; ++c) { + r += static_cast(in_ptr_pt[0]) * + static_cast(wt_ptr_pt[0]); + in_ptr_pt += in_stride_C; + wt_ptr_pt += wt_stride_C; + } // c + + } // iD, ih, iw check + } // ww + } // wh + } // wd + + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + }; + + int oD_border_0 = 0; + int oD_border_1 = + is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; + int oD_border_2 = std::max( + oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); + int oD_border_3 = oD; + + int oH_border_0 = 0; + int oH_border_1 = + is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; + int oH_border_2 = std::max( + oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); + int oH_border_3 = oH; + + int oW_border_0 = 0; + int oW_border_1 = + is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; + int oW_border_2 = std::max( + oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); + int oW_border_3 = oW; + + for (int n = 0; n < N; ++n) { + // Case 1: od might put us out of bounds + for (int od = oD_border_0; od < oD_border_1; ++od) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + } // od + + // Case 2: od in bounds + for (int od = oD_border_1; od < oD_border_2; ++od) { + // Case 2.1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + + // Case 2.2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case 2.2.1: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + + // Case 2.2.2: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + + // Case 2.2.3: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + + // Case 2.3: oh might put us out of bounds + for (int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + } // od + + // Case 3: od might put us out of bounds + for (int od = oD_border_2; od < oD_border_3; ++od) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); + } // ow + } // oh + } // od + + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; + + } // n +} + void dispatch_slow_conv_1D( const array& in, const array& wt, @@ -358,6 +648,30 @@ void dispatch_slow_conv_2D( } } +void dispatch_slow_conv_3D( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { + if (in.dtype() == float32) { + return slow_conv_3D( + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + } else if (in.dtype() == float16) 
{ + return slow_conv_3D( + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + } else if (in.dtype() == bfloat16) { + return slow_conv_3D( + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); + } else { + throw std::invalid_argument( + "[Convolution::eval] got unsupported data type."); + } +} + /////////////////////////////////////////////////////////////////////////////// // Explicit gemm conv /////////////////////////////////////////////////////////////////////////////// @@ -582,6 +896,131 @@ void explicit_gemm_conv_2D_cpu( } } +void explicit_gemm_conv_ND_cpu( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + const int N = in.shape(0); // Batch size, should be the same as out.shape(0) + const auto iDim = std::vector( + in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim + const auto oDim = std::vector( + out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim + const int O = wt.shape(0); // Out channels + const int C = wt.shape(-1); // In channels + const auto wDim = std::vector( + wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim + + auto conv_dtype = float32; + + // Pad input + std::vector padded_shape(in.shape().size()); + padded_shape.front() = N; + for (size_t i = 0; i < iDim.size(); i++) { + padded_shape[i + 1] = iDim[i] + 2 * padding[i]; + } + padded_shape.back() = C; + array in_padded(padded_shape, conv_dtype, nullptr, {}); + + // Fill with zeros + copy(array(0, conv_dtype), in_padded, CopyType::Scalar); + + // Pick input slice from padded + size_t data_offset = 0; + for (size_t i = 0; i < padding.size(); i++) { + data_offset += padding[i] * in_padded.strides()[i + 1]; + } + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); + in_padded_slice.copy_shared_buffer( + in_padded, + in_padded.strides(), + in_padded.flags(), + in_padded_slice.size(), + data_offset); + + // Copy input values into the slice + copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); + + // Make strided view + std::vector strided_shape(oDim.size() + wDim.size() + 2); + strided_shape.front() = N; + for (size_t i = 0; i < oDim.size(); i++) { + strided_shape[i + 1] = oDim[i]; + } + for (size_t i = 0; i < wDim.size(); i++) { + strided_shape[i + 1 + oDim.size()] = wDim[i]; + } + strided_shape.back() = C; + + std::vector strided_strides(in.shape().size() * 2 - 2); + strided_strides[0] = in_padded.strides()[0]; + for (size_t i = 0; i < wt_strides.size(); i++) { + strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i]; + } + for (size_t i = 1; i < in_padded.strides().size(); i++) { + strided_strides[i + wt_strides.size()] = in_padded.strides()[i]; + } + + auto flags = in_padded.flags(); + + array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); + in_strided_view.copy_shared_buffer( + in_padded, strided_strides, flags, in_strided_view.size(), 0); + + // Materialize strided view + std::vector strided_reshape = {N, C}; + for (const auto& o : oDim) { + strided_reshape[0] *= o; + } + for (const auto& w : wDim) { + strided_reshape[1] *= w; + } + + array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); + copy(in_strided_view, in_strided, CopyType::General); + + // Check wt dtype and prepare + auto gemm_wt = wt; + auto gemm_out = out; + + if (wt.dtype() != float32 || !wt.flags().row_contiguous) { + auto ctype = + wt.flags().row_contiguous ? 
CopyType::Vector : CopyType::General; + gemm_wt = array(wt.shape(), float32, nullptr, {}); + copy(wt, gemm_wt, ctype); + } + + if (out.dtype() != float32) { + gemm_out = array(out.shape(), float32, nullptr, {}); + gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + } + + // Perform gemm + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, // no trans A + CblasTrans, // transB + strided_reshape[0], // M + O, // N + strided_reshape[1], // K + 1.0f, // alpha + in_strided.data(), + strided_reshape[1], // lda + gemm_wt.data(), + strided_reshape[1], // ldb + 0.0f, // beta + gemm_out.data(), + O // ldc + ); + + // Copy results if needed + if (out.dtype() != float32) { + copy(gemm_out, out, CopyType::Vector); + } +} + /////////////////////////////////////////////////////////////////////////////// // Conv routing /////////////////////////////////////////////////////////////////////////////// @@ -617,6 +1056,19 @@ void conv_2D_cpu( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } +void conv_3D_cpu( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { + return dispatch_slow_conv_3D( + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); +} + } // namespace void Convolution::eval(const std::vector& inputs, array& out) { @@ -625,8 +1077,20 @@ void Convolution::eval(const std::vector& inputs, array& out) { auto& in = inputs[0]; auto& wt = inputs[1]; + // 3D convolution + if (in.ndim() == (3 + 2)) { + return conv_3D_cpu( + in, + wt, + out, + padding_, + kernel_strides_, + kernel_dilation_, + input_dilation_, + flip_); + } // 2D convolution - if (in.ndim() == (2 + 2)) { + else if (in.ndim() == (2 + 2)) { return conv_2D_cpu( in, wt, diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 165d66050..7fcb6592e 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -759,6 +759,56 @@ void conv_2D_gpu( } } +void conv_3D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip, + std::vector& copies) { + // Make conv params + MLXConvParams<3> conv_params{ + /* const int N = */ in.shape(0), + /* const int C = */ in.shape(4), + /* const int O = */ wt.shape(0), + /* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)}, + /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)}, + /* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)}, + /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]}, + /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]}, + /* const int kdil[NDIM] = */ + {wt_dilation[0], wt_dilation[1], wt_dilation[2]}, + /* const int idil[NDIM] = */ + {in_dilation[0], in_dilation[1], in_dilation[2]}, + /* const size_t in_strides[NDIM + 2] = */ + {in.strides()[0], + in.strides()[1], + in.strides()[2], + in.strides()[3], + in.strides()[4]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], + wt.strides()[1], + wt.strides()[2], + wt.strides()[3], + wt.strides()[4]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], + out.strides()[1], + out.strides()[2], + out.strides()[3], + out.strides()[4]}, + /* const int groups = */ 1, + /* const bool flip = */ flip, + }; + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, 
conv_params); +} + } // namespace void Convolution::eval_gpu(const std::vector& inputs, array& out) { @@ -783,8 +833,23 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { wt = arr_copy; } + // 3D conv + if (out.ndim() == 5) { + conv_3D_gpu( + s, + d, + in, + wt, + out, + padding_, + kernel_strides_, + kernel_dilation_, + input_dilation_, + flip_, + copies); + } // 2D conv - if (out.ndim() == 4) { + else if (out.ndim() == 4) { conv_2D_gpu( s, d, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 7ab013f4e..6a468e771 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2878,14 +2878,14 @@ inline std::vector conv_out_shape( if (strides.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid strides " << strides << "for " << spatial_dims + msg << "[conv] Invalid strides " << strides << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid pading " << pads_lo << " | " << pads_hi << "for " + msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } @@ -3058,6 +3058,30 @@ array conv2d( s); } +/** 3D convolution with a filter */ +array conv3d( + const array& in_, + const array& wt_, + const std::tuple& stride /* = {1, 1, 1} */, + const std::tuple& padding /* = {0, 0, 0} */, + const std::tuple& dilation /* = {1, 1, 1} */, + int groups /* = 1 */, + StreamOrDevice s /* = {} */) { + return conv_general( + /* const array& input = */ in_, + /* const array& weight = */ wt_, + /* std::vector stride = */ + {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)}, + /* std::vector padding = */ + {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)}, + /* std::vector kernel_dilation = */ + {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)}, + /* std::vector input_dilation = */ {1, 1, 1}, + /* int groups = */ groups, + /* bool flip = */ false, + s); +} + /** General convolution with a filter */ array conv_general( array in, @@ -3078,9 +3102,9 @@ array conv_general( int spatial_dims = in.ndim() - 2; - if (spatial_dims < 1 || spatial_dims > 2) { + if (spatial_dims < 1 || spatial_dims > 3) { throw std::invalid_argument( - "[conv] Can only work with inputs that have 1 or 2 spatial dimensions." + "[conv] Only works for inputs with 1-3 spatial dimensions." 
" The inputs must be in the format [N, ..., C_in]"); } @@ -3120,10 +3144,10 @@ array conv_general( // Check for negative padding bool has_neg_padding = false; for (auto& pd : padding_lo) { - has_neg_padding = (pd < 0); + has_neg_padding |= (pd < 0); } for (auto& pd : padding_hi) { - has_neg_padding = (pd < 0); + has_neg_padding |= (pd < 0); } // Handle negative padding diff --git a/mlx/ops.h b/mlx/ops.h index 2df60362c..c43437f13 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1120,6 +1120,16 @@ array conv2d( int groups = 1, StreamOrDevice s = {}); +/** 3D convolution with a filter */ +array conv3d( + const array& input, + const array& weight, + const std::tuple& stride = {1, 1, 1}, + const std::tuple& padding = {0, 0, 0}, + const std::tuple& dilation = {1, 1, 1}, + int groups = 1, + StreamOrDevice s = {}); + /** Quantized matmul multiplies x with a quantized matrix w*/ array quantized_matmul( const array& x, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index d9c0739f0..b938c6afb 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -894,20 +894,62 @@ std::vector Convolution::vjp( padding_hi[i] = in_size - out_size + padding_[i]; } + // Check for negative padding + bool has_neg_padding = false; + for (auto& pd : padding_lo) { + has_neg_padding |= (pd < 0); + } + for (auto& pd : padding_hi) { + has_neg_padding |= (pd < 0); + } + + auto padding_lo_ = std::vector(padding_lo); + auto padding_hi_ = std::vector(padding_hi); + + // Use negative padding on the gradient output + if (has_neg_padding) { + for (auto& p : padding_lo_) { + p = std::max(0, p); + } + for (auto& p : padding_hi_) { + p = std::max(0, p); + } + } + auto wt_trans = swapaxes(wt, 0, -1, stream()); auto grad = conv_general( /* const array& input = */ cotan, /* const array& weight = */ wt_trans, /* std::vector stride = */ input_dilation_, - /* std::vector padding_lo = */ padding_lo, - /* std::vector padding_hi = */ padding_hi, + /* std::vector padding_lo = */ padding_lo_, + /* std::vector padding_hi = */ padding_hi_, /* std::vector kernel_dilation = */ kernel_dilation_, /* std::vector input_dilation = */ kernel_strides_, /* int groups = */ 1, /* bool flip = */ !flip_, stream()); + // Handle negative padding + if (has_neg_padding) { + std::vector starts(grad.ndim(), 0); + std::vector stops = grad.shape(); + + for (int i = 0; i < grad.ndim() - 2; i++) { + if (padding_lo[i] < 0) { + starts[i + 1] -= padding_lo[i]; + padding_lo[i] = 0; + } + + if (padding_hi[i] < 0) { + stops[i + 1] += padding_hi[i]; + padding_hi[i] = 0; + } + } + + grad = slice(grad, std::move(starts), std::move(stops), stream()); + } + grads.push_back(grad); } // Grads for weight diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index fce721a06..be4935626 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -48,7 +48,7 @@ from mlx.nn.layers.activations import ( ) from mlx.nn.layers.base import Module from mlx.nn.layers.containers import Sequential -from mlx.nn.layers.convolution import Conv1d, Conv2d +from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Bilinear, Identity, Linear diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py index 6e1c9780e..0126c6d2a 100644 --- a/python/mlx/nn/layers/convolution.py +++ b/python/mlx/nn/layers/convolution.py @@ -132,3 +132,66 @@ class Conv2d(Module): if "bias" in self: y 
= y + self.bias return y + + +class Conv3d(Module): + """Applies a 3-dimensional convolution over the multi-channel input image. + The channels are expected to be last i.e. the input shape should be ``NDHWC`` where: + - ``N`` is the batch dimension + - ``D`` is the input image depth + - ``H`` is the input image height + - ``W`` is the input image width + - ``C`` is the number of input channels + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int or tuple): The size of the convolution filters. + stride (int or tuple, optional): The size of the stride when + applying the filter. Default: ``1``. + padding (int or tuple, optional): How many positions to 0-pad + the input with. Default: ``0``. + bias (bool, optional): If ``True`` add a learnable bias to the + output. Default: ``True`` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, tuple], + stride: Union[int, tuple] = 1, + padding: Union[int, tuple] = 0, + bias: bool = True, + ): + super().__init__() + + kernel_size, stride, padding = map( + lambda x: (x, x, x) if isinstance(x, int) else x, + (kernel_size, stride, padding), + ) + scale = math.sqrt( + 1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) + ) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, *kernel_size, in_channels), + ) + if bias: + self.bias = mx.zeros((out_channels,)) + + self.padding = padding + self.stride = stride + + def _extra_repr(self): + return ( + f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " + f"padding={self.padding}, bias={'bias' in self}" + ) + + def __call__(self, x): + y = mx.conv3d(x, self.weight, self.stride, self.padding) + if "bias" in self: + y = y + self.bias + return y diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 551d7ddda..b468be0b0 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3230,6 +3230,78 @@ void init_ops(nb::module_& m) { array: The convolved array. )pbdoc"); m.def( + "conv3d", + [](const array& input, + const array& weight, + const std::variant>& stride, + const std::variant>& padding, + const std::variant>& dilation, + int groups, + StreamOrDevice s) { + std::tuple stride_tuple{1, 1, 1}; + std::tuple padding_tuple{0, 0, 0}; + std::tuple dilation_tuple{1, 1, 1}; + + if (auto pv = std::get_if(&stride); pv) { + stride_tuple = std::tuple{*pv, *pv, *pv}; + } else { + stride_tuple = std::get>(stride); + } + + if (auto pv = std::get_if(&padding); pv) { + padding_tuple = std::tuple{*pv, *pv, *pv}; + } else { + padding_tuple = std::get>(padding); + } + + if (auto pv = std::get_if(&dilation); pv) { + dilation_tuple = std::tuple{*pv, *pv, *pv}; + } else { + dilation_tuple = std::get>(dilation); + } + + return conv3d( + input, + weight, + stride_tuple, + padding_tuple, + dilation_tuple, + groups, + s); + }, + nb::arg(), + nb::arg(), + "stride"_a = 1, + "padding"_a = 0, + "dilation"_a = 1, + "groups"_a = 1, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + 3D convolution over an input with several channels + + Note: Only the default ``groups=1`` is currently supported. 
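For orientation, a minimal usage sketch of the new op and layer added in this diff (channels-last ``NDHWC`` input; the shapes and hyperparameters below are illustrative only, and per the note above only ``groups=1`` is supported):

```python
import mlx.core as mx
import mlx.nn as nn

# A batch of 4 volumes, 16x16x16 voxels, 3 input channels (NDHWC layout).
x = mx.random.normal(shape=(4, 16, 16, 16, 3))

# Layer API: the weight is stored as (out_channels, kD, kH, kW, in_channels).
conv = nn.Conv3d(in_channels=3, out_channels=8, kernel_size=3, stride=2, padding=1)
y = conv(x)
print(y.shape)  # (4, 8, 8, 8, 8): each 16-voxel axis maps to 8 with k=3, stride=2, pad=1

# Functional API: the same computation with an explicit weight array.
w = mx.random.normal(shape=(8, 3, 3, 3, 3))
y2 = mx.conv3d(x, w, stride=(2, 2, 2), padding=(1, 1, 1), dilation=(1, 1, 1))
```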
+ + Args: + input (array): input array of shape ``(N, D, H, W, C_in)`` + weight (array): weight array of shape ``(C_out, D, H, W, C_in)`` + stride (int or tuple(int), optional): :obj:`tuple` of size 3 with + kernel strides. All spatial dimensions get the same stride if + only one number is specified. Default: ``1``. + padding (int or tuple(int), optional): :obj:`tuple` of size 3 with + symmetric input padding. All spatial dimensions get the same + padding if only one number is specified. Default: ``0``. + dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with + kernel dilation. All spatial dimensions get the same dilation + if only one number is specified. Default: ``1`` + groups (int, optional): input feature groups. Default: ``1``. + + Returns: + array: The convolved array. + )pbdoc"); + m.def( "conv_general", [](const array& input, const array& weight, diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 8c5585126..1111ce80a 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -399,7 +399,7 @@ class TestConv(mlx_tests.MLXTestCase): [in_mx, wt_mx], [ct_mx], ) - pt_grad_in = F.grad.conv1d_input( + pt_grad_in = F.grad.conv2d_input( in_pt.shape, wt_pt, ct_pt, @@ -408,7 +408,7 @@ class TestConv(mlx_tests.MLXTestCase): dilation=dilation, groups=groups, ) - pt_grad_wt = F.grad.conv1d_weight( + pt_grad_wt = F.grad.conv2d_weight( in_pt, wt_pt.shape, ct_pt, @@ -444,6 +444,203 @@ class TestConv(mlx_tests.MLXTestCase): N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_3D(self): + def run_conv3D( + N, + C, + O, + idim, + kdim, + stride, + padding, + dilation=(1, 1, 1), + groups=1, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iD, iH, iW = idim + kD, kH, kW = kdim + scale = 1.0 / math.sqrt(kD * kH * kW * C) + in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype( + np_dtype + ) + wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt, wt_pt = map( + lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"), + (in_np, wt_np), + ) + + out_mx = mx.conv3d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + out_pt = torch.conv3d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ( + (1, 1, 1), + (1, 6, 1), + (1, 1, 6), + (4, 16, 32), + ): + for idim, kdim, stride, padding in ( + ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)), + ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)), + ((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)), + ): + run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_3D_grad(self): + def run_conv3D_grad( + N, + C, + O, + idim, + kdim, + stride, + padding, + dilation=(1, 1, 1), + groups=1, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + dilation=dilation, + 
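The gradient test below builds the cotangent with the standard convolution output-size formula; as a quick standalone check of that arithmetic (the helper name here is ours, not part of the test suite):

```python
def conv_out_size(i, k, stride, pad, dil):
    # floor((i + 2*pad - dil*(k - 1) - 1) / stride) + 1
    return (i + 2 * pad - dil * (k - 1) - 1) // stride + 1

# Two of the configurations exercised by the tests:
assert conv_out_size(15, 5, stride=5, pad=2, dil=1) == 3
assert conv_out_size(16, 3, stride=2, pad=1, dil=1) == 8
```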
groups=groups, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iD, iH, iW = idim + kD, kH, kW = kdim + scale = 1.0 / math.sqrt(kD * kH * kW * C) + + oD = 1 + ( + (iD + 2 * padding[0] - dilation[0] * (kD - 1) - 1) // stride[0] + ) + oH = 1 + ( + (iH + 2 * padding[1] - dilation[1] * (kH - 1) - 1) // stride[1] + ) + oW = 1 + ( + (iW + 2 * padding[2] - dilation[2] * (kW - 1) - 1) // stride[2] + ) + + in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype( + np_dtype + ) + wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype( + np_dtype + ) + ct_np = np.random.normal(0.0, scale, (N, oD, oH, oW, O)).astype( + np_dtype + ) + + in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np)) + in_pt, wt_pt, ct_pt = map( + lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"), + (in_np, wt_np, ct_np), + ) + + def f(a, b): + return mx.conv3d( + a, + b, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + _, outs_mx = mx.vjp( + f, + [in_mx, wt_mx], + [ct_mx], + ) + pt_grad_in = F.grad.conv3d_input( + in_pt.shape, + wt_pt, + ct_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + pt_grad_wt = F.grad.conv3d_weight( + in_pt, + wt_pt.shape, + ct_pt, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 4, 1)).numpy() + pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 4, 1)).numpy() + + mx_grad_in, mx_grad_wt = outs_mx + + self.assertEqual(pt_grad_in.shape, mx_grad_in.shape) + self.assertEqual(in_mx.shape, mx_grad_in.shape) + self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol)) + + self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape) + self.assertEqual(wt_mx.shape, mx_grad_wt.shape) + self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 16, 32), (4, 8, 16)): + for idim, kdim, stride, padding, dilation in ( + ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)), + ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)), + ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)), + ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)), + ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)), + ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)), + ): + run_conv3D_grad( + N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype + ) + def __conv_general_test( self, in_shape,
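Mirroring the test above, a minimal standalone sketch of pulling gradients back through ``mx.conv3d`` with ``mx.vjp`` (shapes and hyperparameters are illustrative only):

```python
import mlx.core as mx

def f(x, w):
    return mx.conv3d(x, w, stride=(2, 2, 2), padding=(1, 1, 1))

x = mx.random.normal(shape=(1, 16, 16, 16, 4))  # NDHWC input
w = mx.random.normal(shape=(8, 3, 3, 3, 4))     # (O, kD, kH, kW, C) weight
cotan = mx.ones((1, 8, 8, 8, 8))                # same shape as f(x, w)

# vjp returns (outputs, vjps); the vjps match the primals' shapes.
_, (dx, dw) = mx.vjp(f, [x, w], [cotan])
assert dx.shape == x.shape and dw.shape == w.shape
```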