mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Conv3d (#993)
* added conv3d added conv3d implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D * incorporated reviewer comments * fixed test * reduced tensor shapes in test for conv3d * Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion
This commit is contained in:
committed by
GitHub
parent
a9f80d60f6
commit
ff4223904d
36
mlx/ops.cpp
36
mlx/ops.cpp
@@ -2878,14 +2878,14 @@ inline std::vector<int> conv_out_shape(
|
||||
|
||||
if (strides.size() != spatial_dims) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid strides " << strides << "for " << spatial_dims
|
||||
msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
|
||||
<< "D convolution.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid pading " << pads_lo << " | " << pads_hi << "for "
|
||||
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
|
||||
<< spatial_dims << "D convolution.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@@ -3058,6 +3058,30 @@ array conv2d(
|
||||
s);
|
||||
}
|
||||
|
||||
/** 3D convolution with a filter */
|
||||
array conv3d(
|
||||
const array& in_,
|
||||
const array& wt_,
|
||||
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
|
||||
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
|
||||
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
|
||||
int groups /* = 1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return conv_general(
|
||||
/* const array& input = */ in_,
|
||||
/* const array& weight = */ wt_,
|
||||
/* std::vector<int> stride = */
|
||||
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
|
||||
/* std::vector<int> padding = */
|
||||
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
|
||||
/* std::vector<int> kernel_dilation = */
|
||||
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
|
||||
/* std::vector<int> input_dilation = */ {1, 1, 1},
|
||||
/* int groups = */ groups,
|
||||
/* bool flip = */ false,
|
||||
s);
|
||||
}
|
||||
|
||||
/** General convolution with a filter */
|
||||
array conv_general(
|
||||
array in,
|
||||
@@ -3078,9 +3102,9 @@ array conv_general(
|
||||
|
||||
int spatial_dims = in.ndim() - 2;
|
||||
|
||||
if (spatial_dims < 1 || spatial_dims > 2) {
|
||||
if (spatial_dims < 1 || spatial_dims > 3) {
|
||||
throw std::invalid_argument(
|
||||
"[conv] Can only work with inputs that have 1 or 2 spatial dimensions."
|
||||
"[conv] Only works for inputs with 1-3 spatial dimensions."
|
||||
" The inputs must be in the format [N, ..., C_in]");
|
||||
}
|
||||
|
||||
@@ -3120,10 +3144,10 @@ array conv_general(
|
||||
// Check for negative padding
|
||||
bool has_neg_padding = false;
|
||||
for (auto& pd : padding_lo) {
|
||||
has_neg_padding = (pd < 0);
|
||||
has_neg_padding |= (pd < 0);
|
||||
}
|
||||
for (auto& pd : padding_hi) {
|
||||
has_neg_padding = (pd < 0);
|
||||
has_neg_padding |= (pd < 0);
|
||||
}
|
||||
|
||||
// Handle negative padding
|
||||
|
||||
Reference in New Issue
Block a user