mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 19:11:17 +08:00
Conv3d (#993)
* added conv3d added conv3d implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D * incorporated reviewer comments * fixed test * reduced tensor shapes in test for conv3d * Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion
This commit is contained in:
parent
a9f80d60f6
commit
ff4223904d
@ -15,6 +15,7 @@ Layers
|
||||
BatchNorm
|
||||
Conv1d
|
||||
Conv2d
|
||||
Conv3d
|
||||
Dropout
|
||||
Dropout2d
|
||||
Dropout3d
|
||||
|
@ -310,6 +310,296 @@ void slow_conv_2D(
|
||||
} // n
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const T* st_wt_ptr = wt.data<T>();
|
||||
const T* st_in_ptr = in.data<T>();
|
||||
T* st_out_ptr = out.data<T>();
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
||||
const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
|
||||
const int oD = out.shape(1); // Output spatial dim
|
||||
const int oH = out.shape(2); // Output spatial dim
|
||||
const int oW = out.shape(3); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(4); // In channels
|
||||
const int wD = wt.shape(1); // Weight spatial dim
|
||||
const int wH = wt.shape(2); // Weight spatial dim
|
||||
const int wW = wt.shape(3); // Weight spatial dim
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_D = in.strides()[1];
|
||||
const size_t in_stride_H = in.strides()[2];
|
||||
const size_t in_stride_W = in.strides()[3];
|
||||
const size_t in_stride_C = in.strides()[4];
|
||||
|
||||
const size_t wt_stride_O = wt.strides()[0];
|
||||
const size_t wt_stride_D = wt.strides()[1];
|
||||
const size_t wt_stride_H = wt.strides()[2];
|
||||
const size_t wt_stride_W = wt.strides()[3];
|
||||
const size_t wt_stride_C = wt.strides()[4];
|
||||
|
||||
const size_t out_stride_N = out.strides()[0];
|
||||
const size_t out_stride_D = out.strides()[1];
|
||||
const size_t out_stride_H = out.strides()[2];
|
||||
const size_t out_stride_W = out.strides()[3];
|
||||
const size_t out_stride_O = out.strides()[4];
|
||||
|
||||
bool is_idil_one =
|
||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;
|
||||
|
||||
auto pt_conv_no_checks = [&](const T* in_ptr,
|
||||
const T* wt_ptr,
|
||||
T* out_ptr,
|
||||
int od,
|
||||
int oh,
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wd = 0; wd < wD; ++wd) {
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wd_flip = flip ? wD - wd - 1 : wd;
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int id = id_base + wd_flip * wt_dilation[0];
|
||||
int ih = ih_base + wh_flip * wt_dilation[1];
|
||||
int iw = iw_base + ww_flip * wt_dilation[2];
|
||||
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
|
||||
} // ww
|
||||
} // wh
|
||||
} // wd
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
};
|
||||
|
||||
int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
|
||||
int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];
|
||||
|
||||
int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
|
||||
int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
|
||||
int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);
|
||||
|
||||
int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
||||
int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||
int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];
|
||||
|
||||
int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||
int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||
int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];
|
||||
|
||||
std::vector<int> base_d(f_out_jump_d);
|
||||
std::vector<int> base_h(f_out_jump_h);
|
||||
std::vector<int> base_w(f_out_jump_w);
|
||||
|
||||
for (int i = 0; i < f_out_jump_d; ++i) {
|
||||
int id_loop = i * wt_strides[0] - padding[0] + init_d;
|
||||
|
||||
int wd_base = 0;
|
||||
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
|
||||
wd_base++;
|
||||
id_loop += jump_d;
|
||||
}
|
||||
|
||||
base_d[i] = wd_base;
|
||||
}
|
||||
|
||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
|
||||
|
||||
int wh_base = 0;
|
||||
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
|
||||
wh_base++;
|
||||
ih_loop += jump_h;
|
||||
}
|
||||
|
||||
base_h[i] = wh_base;
|
||||
}
|
||||
|
||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
|
||||
|
||||
int ww_base = 0;
|
||||
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
|
||||
ww_base++;
|
||||
iw_loop += jump_w;
|
||||
}
|
||||
|
||||
base_w[j] = ww_base;
|
||||
}
|
||||
|
||||
auto pt_conv_all_checks = [&](const T* in_ptr,
|
||||
const T* wt_ptr,
|
||||
T* out_ptr,
|
||||
int od,
|
||||
int oh,
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
|
||||
int wd_base = base_d[od % f_out_jump_d];
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wd_flip = flip ? wD - wd - 1 : wd;
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int id = id_base + wd_flip * wt_dilation[0];
|
||||
int ih = ih_base + wh_flip * wt_dilation[1];
|
||||
int iw = iw_base + ww_flip * wt_dilation[2];
|
||||
|
||||
if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
|
||||
iw < iW) {
|
||||
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
|
||||
wh * wt_stride_H + ww * wt_stride_W;
|
||||
|
||||
int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;
|
||||
|
||||
const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
|
||||
ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
|
||||
} // iD, ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
} // wd
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
};
|
||||
|
||||
int oD_border_0 = 0;
|
||||
int oD_border_1 =
|
||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
|
||||
int oD_border_2 = std::max(
|
||||
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||
int oD_border_3 = oD;
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 =
|
||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 =
|
||||
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Case 1: od might put us out of bounds
|
||||
for (int od = oD_border_0; od < oD_border_1; ++od) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
// Case 2: od in bounds
|
||||
for (int od = oD_border_1; od < oD_border_2; ++od) {
|
||||
// Case 2.1: oh might put us out of bounds
|
||||
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
// Case 2.2: oh in bounds
|
||||
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||
// Case 2.2.1: ow might put us out of bounds
|
||||
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case 2.2.2: ow in bounds
|
||||
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
|
||||
// Case 2.2.3: ow might put us out of bounds
|
||||
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
|
||||
// Case 2.3: oh might put us out of bounds
|
||||
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
// Case 3: od might put us out of bounds
|
||||
for (int od = oD_border_2; od < oD_border_3; ++od) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int ow = 0; ow < oW; ++ow) {
|
||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
|
||||
} // ow
|
||||
} // oh
|
||||
} // od
|
||||
|
||||
st_in_ptr += in_stride_N;
|
||||
st_out_ptr += out_stride_N;
|
||||
|
||||
} // n
|
||||
}
|
||||
|
||||
void dispatch_slow_conv_1D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
@ -358,6 +648,30 @@ void dispatch_slow_conv_2D(
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
if (in.dtype() == float32) {
|
||||
return slow_conv_3D<float>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else if (in.dtype() == float16) {
|
||||
return slow_conv_3D<float16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else if (in.dtype() == bfloat16) {
|
||||
return slow_conv_3D<bfloat16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval] got unsupported data type.");
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Explicit gemm conv
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -582,6 +896,131 @@ void explicit_gemm_conv_2D_cpu(
|
||||
}
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_ND_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const auto iDim = std::vector<int>(
|
||||
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
|
||||
const auto oDim = std::vector<int>(
|
||||
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(-1); // In channels
|
||||
const auto wDim = std::vector<int>(
|
||||
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
|
||||
|
||||
auto conv_dtype = float32;
|
||||
|
||||
// Pad input
|
||||
std::vector<int> padded_shape(in.shape().size());
|
||||
padded_shape.front() = N;
|
||||
for (size_t i = 0; i < iDim.size(); i++) {
|
||||
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
|
||||
}
|
||||
padded_shape.back() = C;
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = 0;
|
||||
for (size_t i = 0; i < padding.size(); i++) {
|
||||
data_offset += padding[i] * in_padded.strides()[i + 1];
|
||||
}
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
|
||||
strided_shape.front() = N;
|
||||
for (size_t i = 0; i < oDim.size(); i++) {
|
||||
strided_shape[i + 1] = oDim[i];
|
||||
}
|
||||
for (size_t i = 0; i < wDim.size(); i++) {
|
||||
strided_shape[i + 1 + oDim.size()] = wDim[i];
|
||||
}
|
||||
strided_shape.back() = C;
|
||||
|
||||
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
|
||||
strided_strides[0] = in_padded.strides()[0];
|
||||
for (size_t i = 0; i < wt_strides.size(); i++) {
|
||||
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
|
||||
}
|
||||
for (size_t i = 1; i < in_padded.strides().size(); i++) {
|
||||
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
|
||||
}
|
||||
|
||||
auto flags = in_padded.flags();
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {N, C};
|
||||
for (const auto& o : oDim) {
|
||||
strided_reshape[0] *= o;
|
||||
}
|
||||
for (const auto& w : wDim) {
|
||||
strided_reshape[1] *= w;
|
||||
}
|
||||
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
auto gemm_wt = wt;
|
||||
auto gemm_out = out;
|
||||
|
||||
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy(wt, gemm_wt, ctype);
|
||||
}
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
}
|
||||
|
||||
// Perform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
CblasTrans, // transB
|
||||
strided_reshape[0], // M
|
||||
O, // N
|
||||
strided_reshape[1], // K
|
||||
1.0f, // alpha
|
||||
in_strided.data<float>(),
|
||||
strided_reshape[1], // lda
|
||||
gemm_wt.data<float>(),
|
||||
strided_reshape[1], // ldb
|
||||
0.0f, // beta
|
||||
gemm_out.data<float>(),
|
||||
O // ldc
|
||||
);
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy(gemm_out, out, CopyType::Vector);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Conv routing
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -617,6 +1056,19 @@ void conv_2D_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
}
|
||||
|
||||
void conv_3D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
return dispatch_slow_conv_3D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
@ -625,8 +1077,20 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
auto& wt = inputs[1];
|
||||
|
||||
// 3D convolution
|
||||
if (in.ndim() == (3 + 2)) {
|
||||
return conv_3D_cpu(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
flip_);
|
||||
}
|
||||
// 2D convolution
|
||||
if (in.ndim() == (2 + 2)) {
|
||||
else if (in.ndim() == (2 + 2)) {
|
||||
return conv_2D_cpu(
|
||||
in,
|
||||
wt,
|
||||
|
@ -759,6 +759,56 @@ void conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void conv_3D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
MLXConvParams<3> conv_params{
|
||||
/* const int N = */ in.shape(0),
|
||||
/* const int C = */ in.shape(4),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
|
||||
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
|
||||
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
|
||||
/* const int kdil[NDIM] = */
|
||||
{wt_dilation[0], wt_dilation[1], wt_dilation[2]},
|
||||
/* const int idil[NDIM] = */
|
||||
{in_dilation[0], in_dilation[1], in_dilation[2]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0],
|
||||
in.strides()[1],
|
||||
in.strides()[2],
|
||||
in.strides()[3],
|
||||
in.strides()[4]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0],
|
||||
wt.strides()[1],
|
||||
wt.strides()[2],
|
||||
wt.strides()[3],
|
||||
wt.strides()[4]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0],
|
||||
out.strides()[1],
|
||||
out.strides()[2],
|
||||
out.strides()[3],
|
||||
out.strides()[4]},
|
||||
/* const int groups = */ 1,
|
||||
/* const bool flip = */ flip,
|
||||
};
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -783,8 +833,23 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
wt = arr_copy;
|
||||
}
|
||||
|
||||
// 3D conv
|
||||
if (out.ndim() == 5) {
|
||||
conv_3D_gpu(
|
||||
s,
|
||||
d,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
flip_,
|
||||
copies);
|
||||
}
|
||||
// 2D conv
|
||||
if (out.ndim() == 4) {
|
||||
else if (out.ndim() == 4) {
|
||||
conv_2D_gpu(
|
||||
s,
|
||||
d,
|
||||
|
36
mlx/ops.cpp
36
mlx/ops.cpp
@ -2878,14 +2878,14 @@ inline std::vector<int> conv_out_shape(
|
||||
|
||||
if (strides.size() != spatial_dims) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid strides " << strides << "for " << spatial_dims
|
||||
msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
|
||||
<< "D convolution.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid pading " << pads_lo << " | " << pads_hi << "for "
|
||||
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
|
||||
<< spatial_dims << "D convolution.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@ -3058,6 +3058,30 @@ array conv2d(
|
||||
s);
|
||||
}
|
||||
|
||||
/** 3D convolution with a filter */
|
||||
array conv3d(
|
||||
const array& in_,
|
||||
const array& wt_,
|
||||
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
|
||||
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
|
||||
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
|
||||
int groups /* = 1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return conv_general(
|
||||
/* const array& input = */ in_,
|
||||
/* const array& weight = */ wt_,
|
||||
/* std::vector<int> stride = */
|
||||
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
|
||||
/* std::vector<int> padding = */
|
||||
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
|
||||
/* std::vector<int> kernel_dilation = */
|
||||
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
|
||||
/* std::vector<int> input_dilation = */ {1, 1, 1},
|
||||
/* int groups = */ groups,
|
||||
/* bool flip = */ false,
|
||||
s);
|
||||
}
|
||||
|
||||
/** General convolution with a filter */
|
||||
array conv_general(
|
||||
array in,
|
||||
@ -3078,9 +3102,9 @@ array conv_general(
|
||||
|
||||
int spatial_dims = in.ndim() - 2;
|
||||
|
||||
if (spatial_dims < 1 || spatial_dims > 2) {
|
||||
if (spatial_dims < 1 || spatial_dims > 3) {
|
||||
throw std::invalid_argument(
|
||||
"[conv] Can only work with inputs that have 1 or 2 spatial dimensions."
|
||||
"[conv] Only works for inputs with 1-3 spatial dimensions."
|
||||
" The inputs must be in the format [N, ..., C_in]");
|
||||
}
|
||||
|
||||
@ -3120,10 +3144,10 @@ array conv_general(
|
||||
// Check for negative padding
|
||||
bool has_neg_padding = false;
|
||||
for (auto& pd : padding_lo) {
|
||||
has_neg_padding = (pd < 0);
|
||||
has_neg_padding |= (pd < 0);
|
||||
}
|
||||
for (auto& pd : padding_hi) {
|
||||
has_neg_padding = (pd < 0);
|
||||
has_neg_padding |= (pd < 0);
|
||||
}
|
||||
|
||||
// Handle negative padding
|
||||
|
10
mlx/ops.h
10
mlx/ops.h
@ -1120,6 +1120,16 @@ array conv2d(
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** 3D convolution with a filter */
|
||||
array conv3d(
|
||||
const array& input,
|
||||
const array& weight,
|
||||
const std::tuple<int, int, int>& stride = {1, 1, 1},
|
||||
const std::tuple<int, int, int>& padding = {0, 0, 0},
|
||||
const std::tuple<int, int, int>& dilation = {1, 1, 1},
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Quantized matmul multiplies x with a quantized matrix w*/
|
||||
array quantized_matmul(
|
||||
const array& x,
|
||||
|
@ -894,20 +894,62 @@ std::vector<array> Convolution::vjp(
|
||||
padding_hi[i] = in_size - out_size + padding_[i];
|
||||
}
|
||||
|
||||
// Check for negative padding
|
||||
bool has_neg_padding = false;
|
||||
for (auto& pd : padding_lo) {
|
||||
has_neg_padding |= (pd < 0);
|
||||
}
|
||||
for (auto& pd : padding_hi) {
|
||||
has_neg_padding |= (pd < 0);
|
||||
}
|
||||
|
||||
auto padding_lo_ = std::vector<int>(padding_lo);
|
||||
auto padding_hi_ = std::vector<int>(padding_hi);
|
||||
|
||||
// Use negative padding on the gradient output
|
||||
if (has_neg_padding) {
|
||||
for (auto& p : padding_lo_) {
|
||||
p = std::max(0, p);
|
||||
}
|
||||
for (auto& p : padding_hi_) {
|
||||
p = std::max(0, p);
|
||||
}
|
||||
}
|
||||
|
||||
auto wt_trans = swapaxes(wt, 0, -1, stream());
|
||||
|
||||
auto grad = conv_general(
|
||||
/* const array& input = */ cotan,
|
||||
/* const array& weight = */ wt_trans,
|
||||
/* std::vector<int> stride = */ input_dilation_,
|
||||
/* std::vector<int> padding_lo = */ padding_lo,
|
||||
/* std::vector<int> padding_hi = */ padding_hi,
|
||||
/* std::vector<int> padding_lo = */ padding_lo_,
|
||||
/* std::vector<int> padding_hi = */ padding_hi_,
|
||||
/* std::vector<int> kernel_dilation = */ kernel_dilation_,
|
||||
/* std::vector<int> input_dilation = */ kernel_strides_,
|
||||
/* int groups = */ 1,
|
||||
/* bool flip = */ !flip_,
|
||||
stream());
|
||||
|
||||
// Handle negative padding
|
||||
if (has_neg_padding) {
|
||||
std::vector<int> starts(grad.ndim(), 0);
|
||||
std::vector<int> stops = grad.shape();
|
||||
|
||||
for (int i = 0; i < grad.ndim() - 2; i++) {
|
||||
if (padding_lo[i] < 0) {
|
||||
starts[i + 1] -= padding_lo[i];
|
||||
padding_lo[i] = 0;
|
||||
}
|
||||
|
||||
if (padding_hi[i] < 0) {
|
||||
stops[i + 1] += padding_hi[i];
|
||||
padding_hi[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
grad = slice(grad, std::move(starts), std::move(stops), stream());
|
||||
}
|
||||
|
||||
grads.push_back(grad);
|
||||
}
|
||||
// Grads for weight
|
||||
|
@ -48,7 +48,7 @@ from mlx.nn.layers.activations import (
|
||||
)
|
||||
from mlx.nn.layers.base import Module
|
||||
from mlx.nn.layers.containers import Sequential
|
||||
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
||||
from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
|
||||
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
||||
from mlx.nn.layers.embedding import Embedding
|
||||
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
||||
|
@ -132,3 +132,66 @@ class Conv2d(Module):
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
||||
|
||||
class Conv3d(Module):
|
||||
"""Applies a 3-dimensional convolution over the multi-channel input image.
|
||||
The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:
|
||||
- ``N`` is the batch dimension
|
||||
- ``D`` is the input image depth
|
||||
- ``H`` is the input image height
|
||||
- ``W`` is the input image width
|
||||
- ``C`` is the number of input channels
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
kernel_size (int or tuple): The size of the convolution filters.
|
||||
stride (int or tuple, optional): The size of the stride when
|
||||
applying the filter. Default: ``1``.
|
||||
padding (int or tuple, optional): How many positions to 0-pad
|
||||
the input with. Default: ``0``.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||
output. Default: ``True``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, tuple],
|
||||
stride: Union[int, tuple] = 1,
|
||||
padding: Union[int, tuple] = 0,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
kernel_size, stride, padding = map(
|
||||
lambda x: (x, x, x) if isinstance(x, int) else x,
|
||||
(kernel_size, stride, padding),
|
||||
)
|
||||
scale = math.sqrt(
|
||||
1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
|
||||
)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, *kernel_size, in_channels),
|
||||
)
|
||||
if bias:
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
self.padding = padding
|
||||
self.stride = stride
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||
f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
|
||||
f"padding={self.padding}, bias={'bias' in self}"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
y = mx.conv3d(x, self.weight, self.stride, self.padding)
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
@ -3230,6 +3230,78 @@ void init_ops(nb::module_& m) {
|
||||
array: The convolved array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conv3d",
|
||||
[](const array& input,
|
||||
const array& weight,
|
||||
const std::variant<int, std::tuple<int, int, int>>& stride,
|
||||
const std::variant<int, std::tuple<int, int, int>>& padding,
|
||||
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
||||
int groups,
|
||||
StreamOrDevice s) {
|
||||
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
||||
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
||||
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
stride_tuple = std::get<std::tuple<int, int, int>>(stride);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&padding); pv) {
|
||||
padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
padding_tuple = std::get<std::tuple<int, int, int>>(padding);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&dilation); pv) {
|
||||
dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
||||
}
|
||||
|
||||
return conv3d(
|
||||
input,
|
||||
weight,
|
||||
stride_tuple,
|
||||
padding_tuple,
|
||||
dilation_tuple,
|
||||
groups,
|
||||
s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
3D convolution over an input with several channels
|
||||
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
|
||||
Args:
|
||||
input (array): input array of shape ``(N, D, H, W, C_in)``
|
||||
weight (array): weight array of shape ``(C_out, D, H, W, C_in)``
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
symmetric input padding. All spatial dimensions get the same
|
||||
padding if only one number is specified. Default: ``0``.
|
||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``
|
||||
groups (int, optional): input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conv_general",
|
||||
[](const array& input,
|
||||
const array& weight,
|
||||
|
@ -399,7 +399,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
[in_mx, wt_mx],
|
||||
[ct_mx],
|
||||
)
|
||||
pt_grad_in = F.grad.conv1d_input(
|
||||
pt_grad_in = F.grad.conv2d_input(
|
||||
in_pt.shape,
|
||||
wt_pt,
|
||||
ct_pt,
|
||||
@ -408,7 +408,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_wt = F.grad.conv1d_weight(
|
||||
pt_grad_wt = F.grad.conv2d_weight(
|
||||
in_pt,
|
||||
wt_pt.shape,
|
||||
ct_pt,
|
||||
@ -444,6 +444,203 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_3D(self):
|
||||
def run_conv3D(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
idim,
|
||||
kdim,
|
||||
stride,
|
||||
padding,
|
||||
dilation=(1, 1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
idim=idim,
|
||||
kdim=kdim,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iD, iH, iW = idim
|
||||
kD, kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kD * kH * kW * C)
|
||||
in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
|
||||
np_dtype
|
||||
)
|
||||
wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt, wt_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
|
||||
(in_np, wt_np),
|
||||
)
|
||||
|
||||
out_mx = mx.conv3d(
|
||||
in_mx,
|
||||
wt_mx,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.conv3d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
|
||||
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
(1, 1, 6),
|
||||
(4, 16, 32),
|
||||
):
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
|
||||
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
|
||||
((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)),
|
||||
):
|
||||
run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_3D_grad(self):
|
||||
def run_conv3D_grad(
|
||||
N,
|
||||
C,
|
||||
O,
|
||||
idim,
|
||||
kdim,
|
||||
stride,
|
||||
padding,
|
||||
dilation=(1, 1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
N=N,
|
||||
C=C,
|
||||
O=O,
|
||||
idim=idim,
|
||||
kdim=kdim,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iD, iH, iW = idim
|
||||
kD, kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kD * kH * kW * C)
|
||||
|
||||
oD = 1 + (
|
||||
(iD + 2 * padding[0] - dilation[0] * (kD - 1) - 1) // stride[0]
|
||||
)
|
||||
oH = 1 + (
|
||||
(iH + 2 * padding[1] - dilation[1] * (kH - 1) - 1) // stride[1]
|
||||
)
|
||||
oW = 1 + (
|
||||
(iW + 2 * padding[2] - dilation[2] * (kW - 1) - 1) // stride[2]
|
||||
)
|
||||
|
||||
in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
|
||||
np_dtype
|
||||
)
|
||||
wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
|
||||
np_dtype
|
||||
)
|
||||
ct_np = np.random.normal(0.0, scale, (N, oD, oH, oW, O)).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
|
||||
in_pt, wt_pt, ct_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
|
||||
(in_np, wt_np, ct_np),
|
||||
)
|
||||
|
||||
def f(a, b):
|
||||
return mx.conv3d(
|
||||
a,
|
||||
b,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
_, outs_mx = mx.vjp(
|
||||
f,
|
||||
[in_mx, wt_mx],
|
||||
[ct_mx],
|
||||
)
|
||||
pt_grad_in = F.grad.conv3d_input(
|
||||
in_pt.shape,
|
||||
wt_pt,
|
||||
ct_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_wt = F.grad.conv3d_weight(
|
||||
in_pt,
|
||||
wt_pt.shape,
|
||||
ct_pt,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 4, 1)).numpy()
|
||||
pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 4, 1)).numpy()
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
|
||||
self.assertEqual(in_mx.shape, mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
|
||||
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 16, 32), (4, 8, 16)):
|
||||
for idim, kdim, stride, padding, dilation in (
|
||||
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
|
||||
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
|
||||
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
|
||||
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
|
||||
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
|
||||
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
|
||||
):
|
||||
run_conv3D_grad(
|
||||
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
||||
)
|
||||
|
||||
def __conv_general_test(
|
||||
self,
|
||||
in_shape,
|
||||
|
Loading…
Reference in New Issue
Block a user