Mirror of https://github.com/ml-explore/mlx.git
Conv3d (#993)
* Added conv3d: implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D
* Incorporated reviewer comments
* Fixed test
* Reduced tensor shapes in test for conv3d
* Applied reviewer suggestions

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent: a9f80d60f6
commit: ff4223904d
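In short, this commit adds 3D convolution across the stack: a slow CPU reference kernel (slow_conv_3D), an N-dimensional explicit-GEMM CPU path, a Metal GPU route, the mx.conv3d op and Python binding, and an nn.Conv3d layer. A minimal usage sketch of what the change enables (shapes follow the NDHWC convention documented below; the concrete numbers here are illustrative):

```python
import mlx.core as mx

# Channels-last 3D convolution: input (N, D, H, W, C_in),
# weight (C_out, kD, kH, kW, C_in), output (N, oD, oH, oW, C_out).
x = mx.random.normal(shape=(2, 8, 8, 8, 4))   # two 8x8x8 volumes, 4 channels
w = mx.random.normal(shape=(16, 3, 3, 3, 4))  # sixteen 3x3x3 filters
y = mx.conv3d(x, w, stride=1, padding=1)
print(y.shape)  # (2, 8, 8, 8, 16): padding=1 preserves D, H, W for a 3x3x3 kernel
```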
@@ -15,6 +15,7 @@ Layers
   BatchNorm
   Conv1d
   Conv2d
+  Conv3d
   Dropout
   Dropout2d
   Dropout3d
@@ -310,6 +310,296 @@ void slow_conv_2D(
  } // n
}

template <typename T>
void slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  const T* st_wt_ptr = wt.data<T>();
  const T* st_in_ptr = in.data<T>();
  T* st_out_ptr = out.data<T>();

  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
  const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
  const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
  const int oD = out.shape(1); // Output spatial dim
  const int oH = out.shape(2); // Output spatial dim
  const int oW = out.shape(3); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(4); // In channels
  const int wD = wt.shape(1); // Weight spatial dim
  const int wH = wt.shape(2); // Weight spatial dim
  const int wW = wt.shape(3); // Weight spatial dim

  const size_t in_stride_N = in.strides()[0];
  const size_t in_stride_D = in.strides()[1];
  const size_t in_stride_H = in.strides()[2];
  const size_t in_stride_W = in.strides()[3];
  const size_t in_stride_C = in.strides()[4];

  const size_t wt_stride_O = wt.strides()[0];
  const size_t wt_stride_D = wt.strides()[1];
  const size_t wt_stride_H = wt.strides()[2];
  const size_t wt_stride_W = wt.strides()[3];
  const size_t wt_stride_C = wt.strides()[4];

  const size_t out_stride_N = out.strides()[0];
  const size_t out_stride_D = out.strides()[1];
  const size_t out_stride_H = out.strides()[2];
  const size_t out_stride_W = out.strides()[3];
  const size_t out_stride_O = out.strides()[4];

  bool is_idil_one =
      in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;

  auto pt_conv_no_checks = [&](const T* in_ptr,
                               const T* wt_ptr,
                               T* out_ptr,
                               int od,
                               int oh,
                               int ow) {
    out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
    int id_base = od * wt_strides[0] - padding[0];
    int ih_base = oh * wt_strides[1] - padding[1];
    int iw_base = ow * wt_strides[2] - padding[2];

    for (int o = 0; o < O; ++o) {
      float r = 0.;

      for (int wd = 0; wd < wD; ++wd) {
        for (int wh = 0; wh < wH; ++wh) {
          for (int ww = 0; ww < wW; ++ww) {
            int wd_flip = flip ? wD - wd - 1 : wd;
            int wh_flip = flip ? wH - wh - 1 : wh;
            int ww_flip = flip ? wW - ww - 1 : ww;
            int id = id_base + wd_flip * wt_dilation[0];
            int ih = ih_base + wh_flip * wt_dilation[1];
            int iw = iw_base + ww_flip * wt_dilation[2];

            const T* wt_ptr_pt =
                wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
            const T* in_ptr_pt =
                in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;

            for (int c = 0; c < C; ++c) {
              r += static_cast<float>(in_ptr_pt[0]) *
                  static_cast<float>(wt_ptr_pt[0]);
              in_ptr_pt += in_stride_C;
              wt_ptr_pt += wt_stride_C;
            } // c

          } // ww
        } // wh
      } // wd

      out_ptr[0] = static_cast<T>(r);
      out_ptr += out_stride_O;
      wt_ptr += wt_stride_O;
    } // o
  };

  int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
  int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
  int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];

  int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
  int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
  int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);

  int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
  int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
  int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];

  int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
  int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
  int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];

  std::vector<int> base_d(f_out_jump_d);
  std::vector<int> base_h(f_out_jump_h);
  std::vector<int> base_w(f_out_jump_w);

  for (int i = 0; i < f_out_jump_d; ++i) {
    int id_loop = i * wt_strides[0] - padding[0] + init_d;

    int wd_base = 0;
    while (wd_base < wD && id_loop % in_dilation[0] != 0) {
      wd_base++;
      id_loop += jump_d;
    }

    base_d[i] = wd_base;
  }

  for (int i = 0; i < f_out_jump_h; ++i) {
    int ih_loop = i * wt_strides[1] - padding[1] + init_h;

    int wh_base = 0;
    while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
      wh_base++;
      ih_loop += jump_h;
    }

    base_h[i] = wh_base;
  }

  for (int j = 0; j < f_out_jump_w; ++j) {
    int iw_loop = j * wt_strides[2] - padding[2] + init_w;

    int ww_base = 0;
    while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
      ww_base++;
      iw_loop += jump_w;
    }

    base_w[j] = ww_base;
  }

  auto pt_conv_all_checks = [&](const T* in_ptr,
                                const T* wt_ptr,
                                T* out_ptr,
                                int od,
                                int oh,
                                int ow) {
    out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;

    int id_base = od * wt_strides[0] - padding[0];
    int ih_base = oh * wt_strides[1] - padding[1];
    int iw_base = ow * wt_strides[2] - padding[2];

    int wd_base = base_d[od % f_out_jump_d];
    int wh_base = base_h[oh % f_out_jump_h];
    int ww_base = base_w[ow % f_out_jump_w];

    for (int o = 0; o < O; ++o) {
      float r = 0.;

      for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
        for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
          for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
            int wd_flip = flip ? wD - wd - 1 : wd;
            int wh_flip = flip ? wH - wh - 1 : wh;
            int ww_flip = flip ? wW - ww - 1 : ww;
            int id = id_base + wd_flip * wt_dilation[0];
            int ih = ih_base + wh_flip * wt_dilation[1];
            int iw = iw_base + ww_flip * wt_dilation[2];

            if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
                iw < iW) {
              const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
                  wh * wt_stride_H + ww * wt_stride_W;

              int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
              int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
              int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;

              const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
                  ih_dil * in_stride_H + iw_dil * in_stride_W;

              for (int c = 0; c < C; ++c) {
                r += static_cast<float>(in_ptr_pt[0]) *
                    static_cast<float>(wt_ptr_pt[0]);
                in_ptr_pt += in_stride_C;
                wt_ptr_pt += wt_stride_C;
              } // c

            } // iD, ih, iw check
          } // ww
        } // wh
      } // wd

      out_ptr[0] = static_cast<T>(r);
      out_ptr += out_stride_O;
      wt_ptr += wt_stride_O;
    } // o
  };

  int oD_border_0 = 0;
  int oD_border_1 =
      is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
  int oD_border_2 = std::max(
      oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
  int oD_border_3 = oD;

  int oH_border_0 = 0;
  int oH_border_1 =
      is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
  int oH_border_2 = std::max(
      oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
  int oH_border_3 = oH;

  int oW_border_0 = 0;
  int oW_border_1 =
      is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
  int oW_border_2 = std::max(
      oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
  int oW_border_3 = oW;

  for (int n = 0; n < N; ++n) {
    // Case 1: od might put us out of bounds
    for (int od = oD_border_0; od < oD_border_1; ++od) {
      for (int oh = 0; oh < oH; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh
    } // od

    // Case 2: od in bounds
    for (int od = oD_border_1; od < oD_border_2; ++od) {
      // Case 2.1: oh might put us out of bounds
      for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh

      // Case 2.2: oh in bounds
      for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
        // Case 2.2.1: ow might put us out of bounds
        for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow

        // Case 2.2.2: ow in bounds
        for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
          pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow

        // Case 2.2.3: ow might put us out of bounds
        for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh

      // Case 2.3: oh might put us out of bounds
      for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh
    } // od

    // Case 3: od might put us out of bounds
    for (int od = oD_border_2; od < oD_border_3; ++od) {
      for (int oh = 0; oh < oH; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh
    } // od

    st_in_ptr += in_stride_N;
    st_out_ptr += out_stride_N;

  } // n
}

void dispatch_slow_conv_1D(
    const array& in,
    const array& wt,
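To make the two lambdas concrete: at each output location they compute a dot product between a (possibly flipped, dilated) kernel window and the input, accumulating in float32 regardless of T; pt_conv_no_checks skips the bounds test because the border analysis guarantees the window is fully inside. A hedged NumPy reference for a single output point, mirroring the checked variant (the helper name and layout assumptions are mine, not MLX code):

```python
import numpy as np

def conv3d_point(x, w, od, oh, ow, stride, pad, dil, flip=False):
    # x: (iD, iH, iW, C) input volume; w: (O, kD, kH, kW, C) weights.
    # Returns the O output channels at output location (od, oh, ow).
    O, kD, kH, kW, C = w.shape
    out = np.zeros(O, dtype=np.float32)
    for o in range(O):
        for wd in range(kD):
            for wh in range(kH):
                for ww in range(kW):
                    # Flip reverses the kernel offsets (cross-correlation vs. convolution).
                    fd = kD - wd - 1 if flip else wd
                    fh = kH - wh - 1 if flip else wh
                    fw = kW - ww - 1 if flip else ww
                    i_d = od * stride[0] - pad[0] + fd * dil[0]
                    i_h = oh * stride[1] - pad[1] + fh * dil[1]
                    i_w = ow * stride[2] - pad[2] + fw * dil[2]
                    if (0 <= i_d < x.shape[0] and 0 <= i_h < x.shape[1]
                            and 0 <= i_w < x.shape[2]):
                        # Accumulate over channels in float32.
                        out[o] += np.dot(x[i_d, i_h, i_w].astype(np.float32),
                                         w[o, wd, wh, ww].astype(np.float32))
    return out
```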
@@ -358,6 +648,30 @@ void dispatch_slow_conv_2D(
  }
}

void dispatch_slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  if (in.dtype() == float32) {
    return slow_conv_3D<float>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == float16) {
    return slow_conv_3D<float16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == bfloat16) {
    return slow_conv_3D<bfloat16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else {
    throw std::invalid_argument(
        "[Convolution::eval] got unsupported data type.");
  }
}

///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
@@ -582,6 +896,131 @@ void explicit_gemm_conv_2D_cpu(
  }
}

void explicit_gemm_conv_ND_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const auto iDim = std::vector<int>(
      in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
  const auto oDim = std::vector<int>(
      out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(-1); // In channels
  const auto wDim = std::vector<int>(
      wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim

  auto conv_dtype = float32;

  // Pad input
  std::vector<int> padded_shape(in.shape().size());
  padded_shape.front() = N;
  for (size_t i = 0; i < iDim.size(); i++) {
    padded_shape[i + 1] = iDim[i] + 2 * padding[i];
  }
  padded_shape.back() = C;
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  copy(array(0, conv_dtype), in_padded, CopyType::Scalar);

  // Pick input slice from padded
  size_t data_offset = 0;
  for (size_t i = 0; i < padding.size(); i++) {
    data_offset += padding[i] * in_padded.strides()[i + 1];
  }
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice
  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);

  // Make strided view
  std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
  strided_shape.front() = N;
  for (size_t i = 0; i < oDim.size(); i++) {
    strided_shape[i + 1] = oDim[i];
  }
  for (size_t i = 0; i < wDim.size(); i++) {
    strided_shape[i + 1 + oDim.size()] = wDim[i];
  }
  strided_shape.back() = C;

  std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
  strided_strides[0] = in_padded.strides()[0];
  for (size_t i = 0; i < wt_strides.size(); i++) {
    strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
  }
  for (size_t i = 1; i < in_padded.strides().size(); i++) {
    strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
  }

  auto flags = in_padded.flags();

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view
  std::vector<int> strided_reshape = {N, C};
  for (const auto& o : oDim) {
    strided_reshape[0] *= o;
  }
  for (const auto& w : wDim) {
    strided_reshape[1] *= w;
  }

  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy(in_strided_view, in_strided, CopyType::General);

  // Check wt dtype and prepare
  auto gemm_wt = wt;
  auto gemm_out = out;

  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
    copy(wt, gemm_wt, ctype);
  }

  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
  }

  // Perform gemm
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans, // no trans A
      CblasTrans, // transB
      strided_reshape[0], // M
      O, // N
      strided_reshape[1], // K
      1.0f, // alpha
      in_strided.data<float>(),
      strided_reshape[1], // lda
      gemm_wt.data<float>(),
      strided_reshape[1], // ldb
      0.0f, // beta
      gemm_out.data<float>(),
      O // ldc
  );

  // Copy results if needed
  if (out.dtype() != float32) {
    copy(gemm_out, out, CopyType::Vector);
  }
}

///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
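The "make strided view" step above is the im2col trick generalized to N spatial dimensions: the padded input is viewed as (N, *out_spatial, *kernel_spatial, C) without copying, then materialized, so the whole convolution collapses into one GEMM against the flattened weights. A hedged NumPy sketch of the same idea (helper name is mine; padding and dilation are omitted here for brevity):

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def conv3d_as_gemm(x, w, stride=(1, 1, 1)):
    # x: (N, iD, iH, iW, C) already padded; w: (O, kD, kH, kW, C).
    N, iD, iH, iW, C = x.shape
    O, kD, kH, kW, _ = w.shape
    oD = (iD - kD) // stride[0] + 1
    oH = (iH - kH) // stride[1] + 1
    oW = (iW - kW) // stride[2] + 1
    sN, sD, sH, sW, sC = x.strides
    # The "strided view": one kernel window per output location, no copy yet.
    view = as_strided(
        x,
        shape=(N, oD, oH, oW, kD, kH, kW, C),
        strides=(sN, sD * stride[0], sH * stride[1], sW * stride[2], sD, sH, sW, sC),
    )
    # Materialize and collapse to a single (M, K) @ (K, O) GEMM,
    # with M = N*oD*oH*oW and K = kD*kH*kW*C, as in the C++ above.
    mat = np.ascontiguousarray(view).reshape(N * oD * oH * oW, kD * kH * kW * C)
    return (mat @ w.reshape(O, -1).T).reshape(N, oD, oH, oW, O)
```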
@@ -617,6 +1056,19 @@ void conv_2D_cpu(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}

void conv_3D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  return dispatch_slow_conv_3D(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}

} // namespace

void Convolution::eval(const std::vector<array>& inputs, array& out) {
@@ -625,8 +1077,20 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];
  auto& wt = inputs[1];

  // 3D convolution
  if (in.ndim() == (3 + 2)) {
    return conv_3D_cpu(
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_);
  }
  // 2D convolution
- if (in.ndim() == (2 + 2)) {
+ else if (in.ndim() == (2 + 2)) {
    return conv_2D_cpu(
        in,
        wt,
@@ -759,6 +759,56 @@ void conv_2D_gpu(
  }
}

void conv_3D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    std::vector<array>& copies) {
  // Make conv params
  MLXConvParams<3> conv_params{
      /* const int N = */ in.shape(0),
      /* const int C = */ in.shape(4),
      /* const int O = */ wt.shape(0),
      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
      /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
      /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
      /* const int kdil[NDIM] = */
      {wt_dilation[0], wt_dilation[1], wt_dilation[2]},
      /* const int idil[NDIM] = */
      {in_dilation[0], in_dilation[1], in_dilation[2]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0],
       in.strides()[1],
       in.strides()[2],
       in.strides()[3],
       in.strides()[4]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0],
       wt.strides()[1],
       wt.strides()[2],
       wt.strides()[3],
       wt.strides()[4]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0],
       out.strides()[1],
       out.strides()[2],
       out.strides()[3],
       out.strides()[4]},
      /* const int groups = */ 1,
      /* const bool flip = */ flip,
  };
  return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}

} // namespace

void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -783,8 +833,23 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
    wt = arr_copy;
  }

  // 3D conv
  if (out.ndim() == 5) {
    conv_3D_gpu(
        s,
        d,
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_,
        copies);
  }
  // 2D conv
- if (out.ndim() == 4) {
+ else if (out.ndim() == 4) {
    conv_2D_gpu(
        s,
        d,
mlx/ops.cpp
@@ -2885,7 +2885,7 @@ inline std::vector<int> conv_out_shape(
  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
-   msg << "[conv] Invalid pading " << pads_lo << " | " << pads_hi << "for "
+   msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }
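Aside: the output spatial extent per dimension follows the standard convolution arithmetic, which the new 3D grad test below computes in the same form (with symmetric padding, so $p_{lo} = p_{hi}$):

$$o = \left\lfloor \frac{i + p_{lo} + p_{hi} - d\,(k - 1) - 1}{s} \right\rfloor + 1$$

where $i$ is the input extent, $k$ the kernel extent, $s$ the stride, and $d$ the kernel dilation.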
@@ -3058,6 +3058,30 @@ array conv2d(
      s);
}

/** 3D convolution with a filter */
array conv3d(
    const array& in_,
    const array& wt_,
    const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
    const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
    const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  return conv_general(
      /* const array& input = */ in_,
      /* const array& weight = */ wt_,
      /* std::vector<int> stride = */
      {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
      /* std::vector<int> padding = */
      {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
      /* std::vector<int> kernel_dilation = */
      {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
      /* std::vector<int> input_dilation = */ {1, 1, 1},
      /* int groups = */ groups,
      /* bool flip = */ false,
      s);
}

/** General convolution with a filter */
array conv_general(
    array in,
@@ -3078,9 +3102,9 @@ array conv_general(
  int spatial_dims = in.ndim() - 2;

- if (spatial_dims < 1 || spatial_dims > 2) {
+ if (spatial_dims < 1 || spatial_dims > 3) {
    throw std::invalid_argument(
-       "[conv] Can only work with inputs that have 1 or 2 spatial dimensions."
+       "[conv] Only works for inputs with 1-3 spatial dimensions."
        " The inputs must be in the format [N, ..., C_in]");
  }
@@ -3120,10 +3144,10 @@ array conv_general(
  // Check for negative padding
  bool has_neg_padding = false;
  for (auto& pd : padding_lo) {
-   has_neg_padding = (pd < 0);
+   has_neg_padding |= (pd < 0);
  }
  for (auto& pd : padding_hi) {
-   has_neg_padding = (pd < 0);
+   has_neg_padding |= (pd < 0);
  }

  // Handle negative padding
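This hunk fixes a real bug: with plain assignment, only the last padding entry decided the flag. A tiny Python illustration of the difference:

```python
pads = [-1, 0, 0]

flag = False
for p in pads:
    flag = (p < 0)   # old behavior: overwritten each iteration
print(flag)          # False -- the -1 is missed

flag = False
for p in pads:
    flag |= (p < 0)  # fixed behavior: accumulate across entries
print(flag)          # True
```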
mlx/ops.h
@@ -1120,6 +1120,16 @@ array conv2d(
    int groups = 1,
    StreamOrDevice s = {});

/** 3D convolution with a filter */
array conv3d(
    const array& input,
    const array& weight,
    const std::tuple<int, int, int>& stride = {1, 1, 1},
    const std::tuple<int, int, int>& padding = {0, 0, 0},
    const std::tuple<int, int, int>& dilation = {1, 1, 1},
    int groups = 1,
    StreamOrDevice s = {});

/** Quantized matmul multiplies x with a quantized matrix w*/
array quantized_matmul(
    const array& x,
@@ -894,20 +894,62 @@ std::vector<array> Convolution::vjp(
        padding_hi[i] = in_size - out_size + padding_[i];
      }

      // Check for negative padding
      bool has_neg_padding = false;
      for (auto& pd : padding_lo) {
        has_neg_padding |= (pd < 0);
      }
      for (auto& pd : padding_hi) {
        has_neg_padding |= (pd < 0);
      }

      auto padding_lo_ = std::vector<int>(padding_lo);
      auto padding_hi_ = std::vector<int>(padding_hi);

      // Use negative padding on the gradient output
      if (has_neg_padding) {
        for (auto& p : padding_lo_) {
          p = std::max(0, p);
        }
        for (auto& p : padding_hi_) {
          p = std::max(0, p);
        }
      }

      auto wt_trans = swapaxes(wt, 0, -1, stream());

      auto grad = conv_general(
          /* const array& input = */ cotan,
          /* const array& weight = */ wt_trans,
          /* std::vector<int> stride = */ input_dilation_,
-         /* std::vector<int> padding_lo = */ padding_lo,
+         /* std::vector<int> padding_lo = */ padding_lo_,
-         /* std::vector<int> padding_hi = */ padding_hi,
+         /* std::vector<int> padding_hi = */ padding_hi_,
          /* std::vector<int> kernel_dilation = */ kernel_dilation_,
          /* std::vector<int> input_dilation = */ kernel_strides_,
          /* int groups = */ 1,
          /* bool flip = */ !flip_,
          stream());

      // Handle negative padding
      if (has_neg_padding) {
        std::vector<int> starts(grad.ndim(), 0);
        std::vector<int> stops = grad.shape();

        for (int i = 0; i < grad.ndim() - 2; i++) {
          if (padding_lo[i] < 0) {
            starts[i + 1] -= padding_lo[i];
            padding_lo[i] = 0;
          }

          if (padding_hi[i] < 0) {
            stops[i + 1] += padding_hi[i];
            padding_hi[i] = 0;
          }
        }

        grad = slice(grad, std::move(starts), std::move(stops), stream());
      }

      grads.push_back(grad);
    }
    // Grads for weight
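The negative-padding handling above runs the transposed convolution with the padding clamped to zero, then slices away the regions that a negative pad would have trimmed. A small Python sketch of the index bookkeeping (all shapes here are made up for illustration):

```python
# Made-up gradient shape (N, D, H, W, C) = (2, 10, 10, 10, 4).
padding_lo = [-1, 2, 0]
padding_hi = [3, -2, 0]
starts = [0, 0, 0, 0, 0]
stops = [2, 10, 10, 10, 4]

for i in range(3):  # spatial dims only; dim 0 is batch, last is channels
    if padding_lo[i] < 0:
        starts[i + 1] -= padding_lo[i]  # -(-1) = 1: drop one leading slice
        padding_lo[i] = 0
    if padding_hi[i] < 0:
        stops[i + 1] += padding_hi[i]   # 10 + (-2) = 8: drop two trailing slices
        padding_hi[i] = 0

print(starts, stops)  # [0, 1, 0, 0, 0] [2, 10, 8, 10, 4]
# grad = slice(grad, starts, stops) keeps grad[:, 1:10, 0:8, :, :]
```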
@@ -48,7 +48,7 @@ from mlx.nn.layers.activations import (
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
-from mlx.nn.layers.convolution import Conv1d, Conv2d
+from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear
@@ -132,3 +132,66 @@ class Conv2d(Module):
        if "bias" in self:
            y = y + self.bias
        return y


class Conv3d(Module):
    """Applies a 3-dimensional convolution over the multi-channel input image.

    The channels are expected to be last, i.e. the input shape should be ``NDHWC`` where:

    - ``N`` is the batch dimension
    - ``D`` is the input image depth
    - ``H`` is the input image height
    - ``W`` is the input image width
    - ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: ``1``.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: ``0``.
        bias (bool, optional): If ``True`` add a learnable bias to the
            output. Default: ``True``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        bias: bool = True,
    ):
        super().__init__()

        kernel_size, stride, padding = map(
            lambda x: (x, x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        scale = math.sqrt(
            1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
        )
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:4]}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        y = mx.conv3d(x, self.weight, self.stride, self.padding)
        if "bias" in self:
            y = y + self.bias
        return y
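A minimal usage sketch of the layer defined above (the shapes are illustrative):

```python
import mlx.core as mx
import mlx.nn as nn

# One Conv3d layer on a batch of 4-channel volumes, NDHWC layout.
layer = nn.Conv3d(in_channels=4, out_channels=8, kernel_size=3, padding=1)
x = mx.random.normal(shape=(2, 16, 16, 16, 4))
y = layer(x)
print(y.shape)  # (2, 16, 16, 16, 8): padding=1 preserves D, H, W for kernel_size=3
```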
@@ -3230,6 +3230,78 @@ void init_ops(nb::module_& m) {
        array: The convolved array.
      )pbdoc");
  m.def(
      "conv3d",
      [](const array& input,
         const array& weight,
         const std::variant<int, std::tuple<int, int, int>>& stride,
         const std::variant<int, std::tuple<int, int, int>>& padding,
         const std::variant<int, std::tuple<int, int, int>>& dilation,
         int groups,
         StreamOrDevice s) {
        std::tuple<int, int, int> stride_tuple{1, 1, 1};
        std::tuple<int, int, int> padding_tuple{0, 0, 0};
        std::tuple<int, int, int> dilation_tuple{1, 1, 1};

        if (auto pv = std::get_if<int>(&stride); pv) {
          stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
        } else {
          stride_tuple = std::get<std::tuple<int, int, int>>(stride);
        }

        if (auto pv = std::get_if<int>(&padding); pv) {
          padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
        } else {
          padding_tuple = std::get<std::tuple<int, int, int>>(padding);
        }

        if (auto pv = std::get_if<int>(&dilation); pv) {
          dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
        } else {
          dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
        }

        return conv3d(
            input,
            weight,
            stride_tuple,
            padding_tuple,
            dilation_tuple,
            groups,
            s);
      },
      nb::arg(),
      nb::arg(),
      "stride"_a = 1,
      "padding"_a = 0,
      "dilation"_a = 1,
      "groups"_a = 1,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        3D convolution over an input with several channels

        Note: Only the default ``groups=1`` is currently supported.

        Args:
            input (array): input array of shape ``(N, D, H, W, C_in)``
            weight (array): weight array of shape ``(C_out, D, H, W, C_in)``
            stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
                kernel strides. All spatial dimensions get the same stride if
                only one number is specified. Default: ``1``.
            padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
                symmetric input padding. All spatial dimensions get the same
                padding if only one number is specified. Default: ``0``.
            dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
                kernel dilation. All spatial dimensions get the same dilation
                if only one number is specified. Default: ``1``
            groups (int, optional): input feature groups. Default: ``1``.

        Returns:
            array: The convolved array.
      )pbdoc");
  m.def(
      "conv_general",
      [](const array& input,
         const array& weight,
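As the docstring above notes, scalar stride/padding/dilation broadcast to all three spatial dimensions, while tuples set them per axis. A quick sketch mixing both forms (shapes are illustrative):

```python
import mlx.core as mx

x = mx.random.normal(shape=(1, 12, 16, 16, 3))
w = mx.random.normal(shape=(8, 3, 3, 3, 3))

# stride=2 broadcasts to (2, 2, 2); padding is set per spatial axis.
y = mx.conv3d(x, w, stride=2, padding=(1, 0, 0))
print(y.shape)  # (1, 6, 7, 7, 8) by the output-size relation above
```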
@@ -399,7 +399,7 @@ class TestConv(mlx_tests.MLXTestCase):
                    [in_mx, wt_mx],
                    [ct_mx],
                )
-               pt_grad_in = F.grad.conv1d_input(
+               pt_grad_in = F.grad.conv2d_input(
                    in_pt.shape,
                    wt_pt,
                    ct_pt,

@@ -408,7 +408,7 @@ class TestConv(mlx_tests.MLXTestCase):
                    dilation=dilation,
                    groups=groups,
                )
-               pt_grad_wt = F.grad.conv1d_weight(
+               pt_grad_wt = F.grad.conv2d_weight(
                    in_pt,
                    wt_pt.shape,
                    ct_pt,
@@ -444,6 +444,203 @@ class TestConv(mlx_tests.MLXTestCase):
                    N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_3D(self):
        def run_conv3D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C)
                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt, wt_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
                    (in_np, wt_np),
                )

                out_mx = mx.conv3d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 16, 32),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)),
                ):
                    run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_3D_grad(self):
        def run_conv3D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C)

                oD = 1 + (
                    (iD + 2 * padding[0] - dilation[0] * (kD - 1) - 1) // stride[0]
                )
                oH = 1 + (
                    (iH + 2 * padding[1] - dilation[1] * (kH - 1) - 1) // stride[1]
                )
                oW = 1 + (
                    (iW + 2 * padding[2] - dilation[2] * (kW - 1) - 1) // stride[2]
                )

                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
                    np_dtype
                )
                ct_np = np.random.normal(0.0, scale, (N, oD, oH, oW, O)).astype(
                    np_dtype
                )

                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
                in_pt, wt_pt, ct_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
                    (in_np, wt_np, ct_np),
                )

                def f(a, b):
                    return mx.conv3d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )
                pt_grad_in = F.grad.conv3d_input(
                    in_pt.shape,
                    wt_pt,
                    ct_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                pt_grad_wt = F.grad.conv3d_weight(
                    in_pt,
                    wt_pt.shape,
                    ct_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 4, 1)).numpy()
                pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 4, 1)).numpy()

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 16, 32), (4, 8, 16)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
                ):
                    run_conv3D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    def __conv_general_test(
        self,
        in_shape,