Fix C++ conv tests

This commit is contained in:
Cheng
2025-07-18 01:29:38 -07:00
parent cea3af6622
commit 180ec0d3a5
3 changed files with 38 additions and 36 deletions

View File

@@ -4,7 +4,6 @@
#define _USE_MATH_DEFINES
#include <cmath>
#include <iostream>
#include <numeric>
#include "doctest/doctest.h"
@@ -3643,7 +3642,6 @@ TEST_CASE("test conv1d") {
{1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>());
}
@@ -3722,7 +3720,6 @@ TEST_CASE("test conv2d") {
auto expected =
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>());
}