* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
This commit is contained in:
Max-Heinrich Laves 2024-05-11 15:15:02 +02:00 committed by GitHub
parent a9f80d60f6
commit ff4223904d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 951 additions and 13 deletions

View File

@ -15,6 +15,7 @@ Layers
BatchNorm
Conv1d
Conv2d
Conv3d
Dropout
Dropout2d
Dropout3d

View File

@ -310,6 +310,296 @@ void slow_conv_2D(
} // n
}
template <typename T>
void slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
const int oD = out.shape(1); // Output spatial dim
const int oH = out.shape(2); // Output spatial dim
const int oW = out.shape(3); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(4); // In channels
const int wD = wt.shape(1); // Weight spatial dim
const int wH = wt.shape(2); // Weight spatial dim
const int wW = wt.shape(3); // Weight spatial dim
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_D = in.strides()[1];
const size_t in_stride_H = in.strides()[2];
const size_t in_stride_W = in.strides()[3];
const size_t in_stride_C = in.strides()[4];
const size_t wt_stride_O = wt.strides()[0];
const size_t wt_stride_D = wt.strides()[1];
const size_t wt_stride_H = wt.strides()[2];
const size_t wt_stride_W = wt.strides()[3];
const size_t wt_stride_C = wt.strides()[4];
const size_t out_stride_N = out.strides()[0];
const size_t out_stride_D = out.strides()[1];
const size_t out_stride_H = out.strides()[2];
const size_t out_stride_W = out.strides()[3];
const size_t out_stride_O = out.strides()[4];
bool is_idil_one =
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;
auto pt_conv_no_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int od,
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wd = 0; wd < wD; ++wd) {
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];
const T* wt_ptr_pt =
wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // ww
} // wh
} // wd
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];
int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);
int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];
int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];
std::vector<int> base_d(f_out_jump_d);
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding[0] + init_d;
int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
wd_base++;
id_loop += jump_d;
}
base_d[i] = wd_base;
}
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int od,
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];
if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
iw < iW) {
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
wh * wt_stride_H + ww * wt_stride_W;
int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;
const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // iD, ih, iw check
} // ww
} // wh
} // wd
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int oD_border_0 = 0;
int oD_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
int oD_border_2 = std::max(
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD;
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
// Case 1: od might put us out of bounds
for (int od = oD_border_0; od < oD_border_1; ++od) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
// Case 2: od in bounds
for (int od = oD_border_1; od < oD_border_2; ++od) {
// Case 2.1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
// Case 2.2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case 2.2.1: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
// Case 2.2.2: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
// Case 2.2.3: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
// Case 2.3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
// Case 3: od might put us out of bounds
for (int od = oD_border_2; od < oD_border_3; ++od) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
} // n
}
void dispatch_slow_conv_1D(
const array& in,
const array& wt,
@ -358,6 +648,30 @@ void dispatch_slow_conv_2D(
}
}
void dispatch_slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (in.dtype() == float32) {
return slow_conv_3D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == float16) {
return slow_conv_3D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == bfloat16) {
return slow_conv_3D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
@ -582,6 +896,131 @@ void explicit_gemm_conv_2D_cpu(
}
}
void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = std::vector<int>(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels
const auto wDim = std::vector<int>(
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
}
padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
// Pick input slice from padded
size_t data_offset = 0;
for (size_t i = 0; i < padding.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1];
}
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
}
for (size_t i = 0; i < wDim.size(); i++) {
strided_shape[i + 1 + oDim.size()] = wDim[i];
}
strided_shape.back() = C;
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
}
for (size_t i = 1; i < in_padded.strides().size(); i++) {
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
}
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N, C};
for (const auto& o : oDim) {
strided_reshape[0] *= o;
}
for (const auto& w : wDim) {
strided_reshape[1] *= w;
}
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
@ -617,6 +1056,19 @@ void conv_2D_cpu(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
void conv_3D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
} // namespace
void Convolution::eval(const std::vector<array>& inputs, array& out) {
@ -625,8 +1077,20 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto& wt = inputs[1];
// 3D convolution
if (in.ndim() == (3 + 2)) {
return conv_3D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// 2D convolution
if (in.ndim() == (2 + 2)) {
else if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in,
wt,

View File

@ -759,6 +759,56 @@ void conv_2D_gpu(
}
}
void conv_3D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<3> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(4),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
/* const int kdil[NDIM] = */
{wt_dilation[0], wt_dilation[1], wt_dilation[2]},
/* const int idil[NDIM] = */
{in_dilation[0], in_dilation[1], in_dilation[2]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0],
in.strides()[1],
in.strides()[2],
in.strides()[3],
in.strides()[4]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0],
wt.strides()[1],
wt.strides()[2],
wt.strides()[3],
wt.strides()[4]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0],
out.strides()[1],
out.strides()[2],
out.strides()[3],
out.strides()[4]},
/* const int groups = */ 1,
/* const bool flip = */ flip,
};
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -783,8 +833,23 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
wt = arr_copy;
}
// 3D conv
if (out.ndim() == 5) {
conv_3D_gpu(
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_,
copies);
}
// 2D conv
if (out.ndim() == 4) {
else if (out.ndim() == 4) {
conv_2D_gpu(
s,
d,

View File

@ -2878,14 +2878,14 @@ inline std::vector<int> conv_out_shape(
if (strides.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid strides " << strides << "for " << spatial_dims
msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
<< "D convolution.";
throw std::invalid_argument(msg.str());
}
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid pading " << pads_lo << " | " << pads_hi << "for "
msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
@ -3058,6 +3058,30 @@ array conv2d(
s);
}
/** 3D convolution with a filter */
array conv3d(
const array& in_,
const array& wt_,
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_general(
/* const array& input = */ in_,
/* const array& weight = */ wt_,
/* std::vector<int> stride = */
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
/* std::vector<int> padding = */
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
/* std::vector<int> kernel_dilation = */
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
/* std::vector<int> input_dilation = */ {1, 1, 1},
/* int groups = */ groups,
/* bool flip = */ false,
s);
}
/** General convolution with a filter */
array conv_general(
array in,
@ -3078,9 +3102,9 @@ array conv_general(
int spatial_dims = in.ndim() - 2;
if (spatial_dims < 1 || spatial_dims > 2) {
if (spatial_dims < 1 || spatial_dims > 3) {
throw std::invalid_argument(
"[conv] Can only work with inputs that have 1 or 2 spatial dimensions."
"[conv] Only works for inputs with 1-3 spatial dimensions."
" The inputs must be in the format [N, ..., C_in]");
}
@ -3120,10 +3144,10 @@ array conv_general(
// Check for negative padding
bool has_neg_padding = false;
for (auto& pd : padding_lo) {
has_neg_padding = (pd < 0);
has_neg_padding |= (pd < 0);
}
for (auto& pd : padding_hi) {
has_neg_padding = (pd < 0);
has_neg_padding |= (pd < 0);
}
// Handle negative padding

View File

@ -1120,6 +1120,16 @@ array conv2d(
int groups = 1,
StreamOrDevice s = {});
/** 3D convolution with a filter */
array conv3d(
const array& input,
const array& weight,
const std::tuple<int, int, int>& stride = {1, 1, 1},
const std::tuple<int, int, int>& padding = {0, 0, 0},
const std::tuple<int, int, int>& dilation = {1, 1, 1},
int groups = 1,
StreamOrDevice s = {});
/** Quantized matmul multiplies x with a quantized matrix w*/
array quantized_matmul(
const array& x,

View File

@ -894,20 +894,62 @@ std::vector<array> Convolution::vjp(
padding_hi[i] = in_size - out_size + padding_[i];
}
// Check for negative padding
bool has_neg_padding = false;
for (auto& pd : padding_lo) {
has_neg_padding |= (pd < 0);
}
for (auto& pd : padding_hi) {
has_neg_padding |= (pd < 0);
}
auto padding_lo_ = std::vector<int>(padding_lo);
auto padding_hi_ = std::vector<int>(padding_hi);
// Use negative padding on the gradient output
if (has_neg_padding) {
for (auto& p : padding_lo_) {
p = std::max(0, p);
}
for (auto& p : padding_hi_) {
p = std::max(0, p);
}
}
auto wt_trans = swapaxes(wt, 0, -1, stream());
auto grad = conv_general(
/* const array& input = */ cotan,
/* const array& weight = */ wt_trans,
/* std::vector<int> stride = */ input_dilation_,
/* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> padding_lo = */ padding_lo_,
/* std::vector<int> padding_hi = */ padding_hi_,
/* std::vector<int> kernel_dilation = */ kernel_dilation_,
/* std::vector<int> input_dilation = */ kernel_strides_,
/* int groups = */ 1,
/* bool flip = */ !flip_,
stream());
// Handle negative padding
if (has_neg_padding) {
std::vector<int> starts(grad.ndim(), 0);
std::vector<int> stops = grad.shape();
for (int i = 0; i < grad.ndim() - 2; i++) {
if (padding_lo[i] < 0) {
starts[i + 1] -= padding_lo[i];
padding_lo[i] = 0;
}
if (padding_hi[i] < 0) {
stops[i + 1] += padding_hi[i];
padding_hi[i] = 0;
}
}
grad = slice(grad, std::move(starts), std::move(stops), stream());
}
grads.push_back(grad);
}
// Grads for weight

View File

@ -48,7 +48,7 @@ from mlx.nn.layers.activations import (
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear

View File

@ -132,3 +132,66 @@ class Conv2d(Module):
if "bias" in self:
y = y + self.bias
return y
class Conv3d(Module):
"""Applies a 3-dimensional convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be ``NDHWC`` where:
- ``N`` is the batch dimension
- ``D`` is the input image depth
- ``H`` is the input image height
- ``W`` is the input image width
- ``C`` is the number of input channels
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int or tuple): The size of the convolution filters.
stride (int or tuple, optional): The size of the stride when
applying the filter. Default: ``1``.
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: ``0``.
bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True``
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, tuple],
stride: Union[int, tuple] = 1,
padding: Union[int, tuple] = 0,
bias: bool = True,
):
super().__init__()
kernel_size, stride, padding = map(
lambda x: (x, x, x) if isinstance(x, int) else x,
(kernel_size, stride, padding),
)
scale = math.sqrt(
1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, *kernel_size, in_channels),
)
if bias:
self.bias = mx.zeros((out_channels,))
self.padding = padding
self.stride = stride
def _extra_repr(self):
return (
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
f"padding={self.padding}, bias={'bias' in self}"
)
def __call__(self, x):
y = mx.conv3d(x, self.weight, self.stride, self.padding)
if "bias" in self:
y = y + self.bias
return y

View File

@ -3230,6 +3230,78 @@ void init_ops(nb::module_& m) {
array: The convolved array.
)pbdoc");
m.def(
"conv3d",
[](const array& input,
const array& weight,
const std::variant<int, std::tuple<int, int, int>>& stride,
const std::variant<int, std::tuple<int, int, int>>& padding,
const std::variant<int, std::tuple<int, int, int>>& dilation,
int groups,
StreamOrDevice s) {
std::tuple<int, int, int> stride_tuple{1, 1, 1};
std::tuple<int, int, int> padding_tuple{0, 0, 0};
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
if (auto pv = std::get_if<int>(&stride); pv) {
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
stride_tuple = std::get<std::tuple<int, int, int>>(stride);
}
if (auto pv = std::get_if<int>(&padding); pv) {
padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
padding_tuple = std::get<std::tuple<int, int, int>>(padding);
}
if (auto pv = std::get_if<int>(&dilation); pv) {
dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
}
return conv3d(
input,
weight,
stride_tuple,
padding_tuple,
dilation_tuple,
groups,
s);
},
nb::arg(),
nb::arg(),
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"groups"_a = 1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
3D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): input array of shape ``(N, D, H, W, C_in)``
weight (array): weight array of shape ``(C_out, D, H, W, C_in)``
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel strides. All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
symmetric input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
groups (int, optional): input feature groups. Default: ``1``.
Returns:
array: The convolved array.
)pbdoc");
m.def(
"conv_general",
[](const array& input,
const array& weight,

View File

@ -399,7 +399,7 @@ class TestConv(mlx_tests.MLXTestCase):
[in_mx, wt_mx],
[ct_mx],
)
pt_grad_in = F.grad.conv1d_input(
pt_grad_in = F.grad.conv2d_input(
in_pt.shape,
wt_pt,
ct_pt,
@ -408,7 +408,7 @@ class TestConv(mlx_tests.MLXTestCase):
dilation=dilation,
groups=groups,
)
pt_grad_wt = F.grad.conv1d_weight(
pt_grad_wt = F.grad.conv2d_weight(
in_pt,
wt_pt.shape,
ct_pt,
@ -444,6 +444,203 @@ class TestConv(mlx_tests.MLXTestCase):
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
)
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_3D(self):
def run_conv3D(
N,
C,
O,
idim,
kdim,
stride,
padding,
dilation=(1, 1, 1),
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
idim=idim,
kdim=kdim,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
iD, iH, iW = idim
kD, kH, kW = kdim
scale = 1.0 / math.sqrt(kD * kH * kW * C)
in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
np_dtype
)
wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
(in_np, wt_np),
)
out_mx = mx.conv3d(
in_mx,
wt_mx,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.conv3d(
in_pt,
wt_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 16, 32),
):
for idim, kdim, stride, padding in (
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)),
):
run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_3D_grad(self):
def run_conv3D_grad(
N,
C,
O,
idim,
kdim,
stride,
padding,
dilation=(1, 1, 1),
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
idim=idim,
kdim=kdim,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
iD, iH, iW = idim
kD, kH, kW = kdim
scale = 1.0 / math.sqrt(kD * kH * kW * C)
oD = 1 + (
(iD + 2 * padding[0] - dilation[0] * (kD - 1) - 1) // stride[0]
)
oH = 1 + (
(iH + 2 * padding[1] - dilation[1] * (kH - 1) - 1) // stride[1]
)
oW = 1 + (
(iW + 2 * padding[2] - dilation[2] * (kW - 1) - 1) // stride[2]
)
in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
np_dtype
)
wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
np_dtype
)
ct_np = np.random.normal(0.0, scale, (N, oD, oH, oW, O)).astype(
np_dtype
)
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
in_pt, wt_pt, ct_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
(in_np, wt_np, ct_np),
)
def f(a, b):
return mx.conv3d(
a,
b,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
_, outs_mx = mx.vjp(
f,
[in_mx, wt_mx],
[ct_mx],
)
pt_grad_in = F.grad.conv3d_input(
in_pt.shape,
wt_pt,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_wt = F.grad.conv3d_weight(
in_pt,
wt_pt.shape,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 4, 1)).numpy()
pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 4, 1)).numpy()
mx_grad_in, mx_grad_wt = outs_mx
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
self.assertEqual(in_mx.shape, mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 16, 32), (4, 8, 16)):
for idim, kdim, stride, padding, dilation in (
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
):
run_conv3D_grad(
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
)
def __conv_general_test(
self,
in_shape,