mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added conv_transpose fixes
This commit is contained in:
@@ -3911,4 +3911,64 @@ TEST_CASE("test bitwise shift operations") {
|
||||
|
||||
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
|
||||
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test conv_transpose1d with output_padding") {
|
||||
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
|
||||
auto wt = array({1.0, 1.0}, {1, 1, 2});
|
||||
int stride = 2;
|
||||
int padding = 0;
|
||||
int output_padding = 1;
|
||||
|
||||
auto out = conv_transpose1d(in, wt, stride, padding, 1, output_padding, 1);
|
||||
auto expected = array({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0}, {1, 1, 7});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test conv_transpose2d with output_padding") {
|
||||
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
|
||||
auto wt = array({1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2});
|
||||
std::pair<int, int> stride{2, 2};
|
||||
std::pair<int, int> padding{0, 0};
|
||||
std::pair<int, int> output_padding{1, 1};
|
||||
|
||||
auto out =
|
||||
conv_transpose2d(in, wt, stride, padding, {1, 1}, output_padding, 1);
|
||||
auto expected = array(
|
||||
{1.0,
|
||||
1.0,
|
||||
2.0,
|
||||
2.0,
|
||||
0.0,
|
||||
3.0,
|
||||
3.0,
|
||||
4.0,
|
||||
4.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0},
|
||||
{1, 1, 5, 5});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test conv_transpose3d with output_padding") {
|
||||
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
|
||||
auto wt = array({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2, 2});
|
||||
std::tuple<int, int, int> stride{2, 2, 2};
|
||||
std::tuple<int, int, int> padding{0, 0, 0};
|
||||
std::tuple<int, int, int> output_padding{1, 1, 1};
|
||||
|
||||
auto out =
|
||||
conv_transpose3d(in, wt, stride, padding, {1, 1, 1}, output_padding, 1);
|
||||
auto expected = array(
|
||||
{1.0, 1.0, 2.0, 2.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 5.0, 5.0, 6.0, 6.0, 0.0, 7.0, 7.0, 8.0, 8.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
|
||||
{1, 1, 5, 5, 5});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
Reference in New Issue
Block a user