mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix recording cudnn conv
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
@@ -3642,6 +3643,7 @@ TEST_CASE("test conv1d") {
|
||||
{1, 3, 4});
|
||||
|
||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||
std::cout << out << std::endl;
|
||||
CHECK(allclose(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
@@ -3720,6 +3722,7 @@ TEST_CASE("test conv2d") {
|
||||
auto expected =
|
||||
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
|
||||
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
||||
std::cout << out << std::endl;
|
||||
CHECK(allclose(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user