mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix: conv_general differences between gpu, cpu (#2070)
* fix general_conv padding * fix bugs * add test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive {
|
||||
explicit Convolution(
|
||||
Stream stream,
|
||||
const std::vector<int>& kernel_strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
const int groups = 1,
|
||||
const bool flip = false)
|
||||
: UnaryPrimitive(stream),
|
||||
padding_(padding),
|
||||
padding_lo_(padding_lo),
|
||||
padding_hi_(padding_hi),
|
||||
kernel_strides_(kernel_strides),
|
||||
kernel_dilation_(kernel_dilation),
|
||||
input_dilation_(input_dilation),
|
||||
@@ -716,7 +718,8 @@ class Convolution : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
@@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive {
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> padding_;
|
||||
std::vector<int> padding_lo_;
|
||||
std::vector<int> padding_hi_;
|
||||
std::vector<int> kernel_strides_;
|
||||
std::vector<int> kernel_dilation_;
|
||||
std::vector<int> input_dilation_;
|
||||
|
||||
Reference in New Issue
Block a user