Transposed Convolution (#1245)

* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Max-Heinrich Laves
2024-09-07 04:52:38 +02:00
committed by GitHub
parent ba3e913c7a
commit efeb9c0f02
15 changed files with 1337 additions and 45 deletions

View File

@@ -1247,6 +1247,36 @@ array conv3d(
int groups = 1,
StreamOrDevice s = {});
/** 1D transposed convolution with a filter */
array conv_transpose1d(
const array& input,
const array& weight,
int stride = 1,
int padding = 0,
int dilation = 1,
int groups = 1,
StreamOrDevice s = {});
/** 2D transposed convolution with a filter */
array conv_transpose2d(
const array& input,
const array& weight,
const std::pair<int, int>& stride = {1, 1},
const std::pair<int, int>& padding = {0, 0},
const std::pair<int, int>& dilation = {1, 1},
int groups = 1,
StreamOrDevice s = {});
/** 3D transposed convolution with a filter */
array conv_transpose3d(
const array& input,
const array& weight,
const std::tuple<int, int, int>& stride = {1, 1, 1},
const std::tuple<int, int, int>& padding = {0, 0, 0},
const std::tuple<int, int, int>& dilation = {1, 1, 1},
int groups = 1,
StreamOrDevice s = {});
/** Quantized matmul multiplies x with a quantized matrix w*/
array quantized_matmul(
const array& x,