mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix cpp tests
This commit is contained in:
@@ -3915,60 +3915,66 @@ TEST_CASE("test bitwise shift operations") {
|
|||||||
|
|
||||||
TEST_CASE("test conv_transpose1d with output_padding") {
|
TEST_CASE("test conv_transpose1d with output_padding") {
|
||||||
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
|
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
|
||||||
auto wt = array({1.0, 1.0}, {1, 1, 2});
|
auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3});
|
||||||
int stride = 2;
|
int stride = 2;
|
||||||
int padding = 0;
|
int padding = 0;
|
||||||
|
int dilation = 1;
|
||||||
int output_padding = 1;
|
int output_padding = 1;
|
||||||
|
int groups = 1;
|
||||||
|
|
||||||
auto out = conv_transpose1d(in, wt, stride, padding, 1, output_padding, 1);
|
auto out = conv_transpose1d(
|
||||||
auto expected = array({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0}, {1, 1, 7});
|
in, wt, stride, padding, dilation, output_padding, groups);
|
||||||
|
auto expected = array({6.0, 0.0}, {1, 2, 1});
|
||||||
CHECK(array_equal(out, expected).item<bool>());
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test conv_transpose2d with output_padding") {
|
TEST_CASE("test conv_transpose2d with output_padding") {
|
||||||
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
|
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
|
||||||
auto wt = array({1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2});
|
auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2});
|
||||||
std::pair<int, int> stride{2, 2};
|
std::pair<int, int> stride{2, 2};
|
||||||
std::pair<int, int> padding{0, 0};
|
std::pair<int, int> padding{0, 0};
|
||||||
std::pair<int, int> output_padding{1, 1};
|
std::pair<int, int> output_padding{1, 1};
|
||||||
|
std::pair<int, int> dilation{1, 1};
|
||||||
|
int groups = 1;
|
||||||
|
|
||||||
auto out =
|
auto out = conv_transpose2d(
|
||||||
conv_transpose2d(in, wt, stride, padding, {1, 1}, output_padding, 1);
|
in, wt, stride, padding, dilation, output_padding, groups);
|
||||||
auto expected = array(
|
auto expected = array(
|
||||||
{1.0,
|
{3.0,
|
||||||
1.0,
|
3.0,
|
||||||
2.0,
|
0.0,
|
||||||
2.0,
|
0.0,
|
||||||
|
7.0,
|
||||||
|
7.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
0.0,
|
0.0,
|
||||||
3.0,
|
|
||||||
3.0,
|
|
||||||
4.0,
|
|
||||||
4.0,
|
|
||||||
0.0,
|
0.0,
|
||||||
0.0,
|
0.0,
|
||||||
0.0,
|
0.0,
|
||||||
0.0,
|
0.0,
|
||||||
0.0,
|
0.0,
|
||||||
0.0},
|
0.0},
|
||||||
{1, 1, 5, 5});
|
{1, 2, 4, 2});
|
||||||
CHECK(array_equal(out, expected).item<bool>());
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test conv_transpose3d with output_padding") {
|
TEST_CASE("test conv_transpose3d with output_padding") {
|
||||||
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
|
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
|
||||||
auto wt = array({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2, 2});
|
auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2});
|
||||||
std::tuple<int, int, int> stride{2, 2, 2};
|
std::tuple<int, int, int> stride{2, 2, 2};
|
||||||
std::tuple<int, int, int> padding{0, 0, 0};
|
std::tuple<int, int, int> padding{0, 0, 0};
|
||||||
std::tuple<int, int, int> output_padding{1, 1, 1};
|
std::tuple<int, int, int> output_padding{1, 1, 1};
|
||||||
|
std::tuple<int, int, int> dilation{1, 1, 1};
|
||||||
|
int groups = 1;
|
||||||
|
|
||||||
auto out =
|
auto out = conv_transpose3d(
|
||||||
conv_transpose3d(in, wt, stride, padding, {1, 1, 1}, output_padding, 1);
|
in, wt, stride, padding, dilation, output_padding, groups);
|
||||||
auto expected = array(
|
auto expected = array(
|
||||||
{1.0, 1.0, 2.0, 2.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 0.0, 0.0,
|
{3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0,
|
||||||
0.0, 0.0, 0.0, 5.0, 5.0, 6.0, 6.0, 0.0, 7.0, 7.0, 8.0, 8.0,
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
|
||||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
{1, 2, 4, 4, 1});
|
||||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
|
|
||||||
{1, 1, 5, 5, 5});
|
|
||||||
CHECK(array_equal(out, expected).item<bool>());
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user