Fix recording cudnn conv

This commit is contained in:
Cheng
2025-07-17 23:48:37 -07:00
parent 6571df6ad7
commit ae9dbb1a9b
4 changed files with 76 additions and 37 deletions

View File

@@ -4,6 +4,7 @@
#define _USE_MATH_DEFINES
#include <cmath>
#include <iostream>
#include <numeric>
#include "doctest/doctest.h"
@@ -3642,6 +3643,7 @@ TEST_CASE("test conv1d") {
{1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>());
}
@@ -3720,6 +3722,7 @@ TEST_CASE("test conv2d") {
auto expected =
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>());
}