mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Transposed Convolution (#1245)
* initial implementation for conv_transpose ran pre-commit implemented conv_transpose updated conv_general docstring updated conv_general docstring updated code comments removed commented run_conv_checks updated acknowledgments added missing entry to ops.rst added op to nn.layers resolved merge conflicts * removed ConvolutionTranspose primitive as suggested by reviewer removed ConvolutionTranspose primitive as suggested by reviewer * remove transpose flag, add another test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
committed by
GitHub
parent
ba3e913c7a
commit
efeb9c0f02
30
mlx/ops.h
30
mlx/ops.h
@@ -1247,6 +1247,36 @@ array conv3d(
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** 1D transposed convolution with a filter */
|
||||
array conv_transpose1d(
|
||||
const array& input,
|
||||
const array& weight,
|
||||
int stride = 1,
|
||||
int padding = 0,
|
||||
int dilation = 1,
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** 2D transposed convolution with a filter */
|
||||
array conv_transpose2d(
|
||||
const array& input,
|
||||
const array& weight,
|
||||
const std::pair<int, int>& stride = {1, 1},
|
||||
const std::pair<int, int>& padding = {0, 0},
|
||||
const std::pair<int, int>& dilation = {1, 1},
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** 3D transposed convolution with a filter */
|
||||
array conv_transpose3d(
|
||||
const array& input,
|
||||
const array& weight,
|
||||
const std::tuple<int, int, int>& stride = {1, 1, 1},
|
||||
const std::tuple<int, int, int>& padding = {0, 0, 0},
|
||||
const std::tuple<int, int, int>& dilation = {1, 1, 1},
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Quantized matmul multiplies x with a quantized matrix w*/
|
||||
array quantized_matmul(
|
||||
const array& x,
|
||||
|
||||
Reference in New Issue
Block a user