fix: conv_general differences between gpu, cpu (#2070)

* fix general_conv padding

* fix bugs

* add test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
ATurker
2025-05-09 20:26:52 +03:00
committed by GitHub
parent 0cae0bdac8
commit a7fae8a176
6 changed files with 413 additions and 270 deletions

View File

@@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive {
explicit Convolution(
Stream stream,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
const int groups = 1,
const bool flip = false)
: UnaryPrimitive(stream),
padding_(padding),
padding_lo_(padding_lo),
padding_hi_(padding_hi),
kernel_strides_(kernel_strides),
kernel_dilation_(kernel_dilation),
input_dilation_(input_dilation),
@@ -716,7 +718,8 @@ class Convolution : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(
padding_,
padding_lo_,
padding_hi_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
@@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive {
}
private:
std::vector<int> padding_;
std::vector<int> padding_lo_;
std::vector<int> padding_hi_;
std::vector<int> kernel_strides_;
std::vector<int> kernel_dilation_;
std::vector<int> input_dilation_;