mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Added output_padding parameters in conv_transpose (#2092)
This commit is contained in:
@@ -3911,4 +3911,70 @@ TEST_CASE("test bitwise shift operations") {
|
||||
|
||||
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
|
||||
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test conv_transpose1d with output_padding") {
|
||||
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
|
||||
auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3});
|
||||
int stride = 2;
|
||||
int padding = 0;
|
||||
int dilation = 1;
|
||||
int output_padding = 1;
|
||||
int groups = 1;
|
||||
|
||||
auto out = conv_transpose1d(
|
||||
in, wt, stride, padding, dilation, output_padding, groups);
|
||||
auto expected = array({6.0, 0.0}, {1, 2, 1});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test conv_transpose2d with output_padding") {
|
||||
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
|
||||
auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2});
|
||||
std::pair<int, int> stride{2, 2};
|
||||
std::pair<int, int> padding{0, 0};
|
||||
std::pair<int, int> output_padding{1, 1};
|
||||
std::pair<int, int> dilation{1, 1};
|
||||
int groups = 1;
|
||||
|
||||
auto out = conv_transpose2d(
|
||||
in, wt, stride, padding, dilation, output_padding, groups);
|
||||
auto expected = array(
|
||||
{3.0,
|
||||
3.0,
|
||||
0.0,
|
||||
0.0,
|
||||
7.0,
|
||||
7.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0},
|
||||
{1, 2, 4, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test conv_transpose3d with output_padding") {
|
||||
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
|
||||
auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2});
|
||||
std::tuple<int, int, int> stride{2, 2, 2};
|
||||
std::tuple<int, int, int> padding{0, 0, 0};
|
||||
std::tuple<int, int, int> output_padding{1, 1, 1};
|
||||
std::tuple<int, int, int> dilation{1, 1, 1};
|
||||
int groups = 1;
|
||||
|
||||
auto out = conv_transpose3d(
|
||||
in, wt, stride, padding, dilation, output_padding, groups);
|
||||
auto expected = array(
|
||||
{3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
|
||||
{1, 2, 4, 4, 1});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
Reference in New Issue
Block a user