mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Conv3d (#993)
* Added conv3d: implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D * Incorporated reviewer comments * Fixed test * Reduced tensor shapes in the conv3d test * Applied reviewer suggestions (Co-authored-by: Awni Hannun <awni.hannun@gmail.com>)
This commit is contained in:

committed by
GitHub

parent
a9f80d60f6
commit
ff4223904d
@@ -759,6 +759,56 @@ void conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void conv_3D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
MLXConvParams<3> conv_params{
|
||||
/* const int N = */ in.shape(0),
|
||||
/* const int C = */ in.shape(4),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
|
||||
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
|
||||
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
|
||||
/* const int kdil[NDIM] = */
|
||||
{wt_dilation[0], wt_dilation[1], wt_dilation[2]},
|
||||
/* const int idil[NDIM] = */
|
||||
{in_dilation[0], in_dilation[1], in_dilation[2]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0],
|
||||
in.strides()[1],
|
||||
in.strides()[2],
|
||||
in.strides()[3],
|
||||
in.strides()[4]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0],
|
||||
wt.strides()[1],
|
||||
wt.strides()[2],
|
||||
wt.strides()[3],
|
||||
wt.strides()[4]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0],
|
||||
out.strides()[1],
|
||||
out.strides()[2],
|
||||
out.strides()[3],
|
||||
out.strides()[4]},
|
||||
/* const int groups = */ 1,
|
||||
/* const bool flip = */ flip,
|
||||
};
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -783,8 +833,23 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
wt = arr_copy;
|
||||
}
|
||||
|
||||
// 3D conv
|
||||
if (out.ndim() == 5) {
|
||||
conv_3D_gpu(
|
||||
s,
|
||||
d,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
flip_,
|
||||
copies);
|
||||
}
|
||||
// 2D conv
|
||||
if (out.ndim() == 4) {
|
||||
else if (out.ndim() == 4) {
|
||||
conv_2D_gpu(
|
||||
s,
|
||||
d,
|
||||
|
Reference in New Issue
Block a user