* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
This commit is contained in:
Max-Heinrich Laves
2024-05-11 15:15:02 +02:00
committed by GitHub
parent a9f80d60f6
commit ff4223904d
10 changed files with 951 additions and 13 deletions

View File

@@ -1120,6 +1120,16 @@ array conv2d(
int groups = 1,
StreamOrDevice s = {});
/** 3D convolution with a filter */
array conv3d(
const array& input,
const array& weight,
const std::tuple<int, int, int>& stride = {1, 1, 1},
const std::tuple<int, int, int>& padding = {0, 0, 0},
const std::tuple<int, int, int>& dilation = {1, 1, 1},
int groups = 1,
StreamOrDevice s = {});
/** Quantized matmul multiplies x with a quantized matrix w*/
array quantized_matmul(
const array& x,