Fix C++ conv tests

This commit is contained in:
Cheng
2025-07-18 01:29:38 -07:00
parent cea3af6622
commit 180ec0d3a5
3 changed files with 38 additions and 36 deletions

View File

@@ -29,19 +29,6 @@ using namespace cudnn_frontend;
throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
}
auto swapaxes(const array& in, int axis1, int axis2) {
std::vector<int> axes(in.ndim());
std::iota(axes.begin(), axes.end(), 0);
std::swap(axes[axis1], axes[axis2]);
std::vector<int64_t> shape(in.ndim());
std::vector<int64_t> strides(in.ndim());
for (size_t ax = 0; ax < axes.size(); ++ax) {
shape[ax] = in.shape()[axes[ax]];
strides[ax] = in.strides()[axes[ax]];
}
return std::make_tuple(shape, strides);
}
class Convolution {
public:
Convolution(
@@ -56,8 +43,7 @@ class Convolution {
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding_lo,
const std::vector<int64_t>& padding_hi,
const std::vector<int64_t>& dilation,
int groups)
const std::vector<int64_t>& dilation)
: handle_(device.cudnn_handle()) {
auto cudnn_type = dtype_to_cudnn_type(dtype);
bool is_half = dtype == float16 || dtype == bfloat16;
@@ -98,7 +84,7 @@ class Convolution {
}
void run(
cu::CommandEncoder& encoder,
CommandEncoder& encoder,
const void* input,
const void* filter,
void* output) {
@@ -160,37 +146,57 @@ class Convolution {
} // namespace cu
namespace {
template <typename T, typename U>
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
return std::vector<T>(vec.begin(), vec.end());
}
auto nhwc_to_nchw(const array& in) {
auto shape = convert_vector<int64_t>(in.shape());
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
auto strides = convert_vector<int64_t>(in.strides());
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(shape, strides);
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Convolution::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
const array& input = inputs[0];
const array& filter = inputs[1];
const array& in = inputs[0];
const array& wt = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_input_array(wt);
encoder.set_output_array(out);
// cuDNN requires dims to be passed as NCHW.
int ndim = input.ndim();
auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1);
auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1);
auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1);
auto [input_shape, input_strides] = nhwc_to_nchw(in);
auto [filter_shape, filter_strides] = nhwc_to_nchw(wt);
auto [output_shape, output_strides] = nhwc_to_nchw(out);
cu::Convolution conv(
cu::device(s.device),
input.dtype(),
in.dtype(),
input_shape,
input_strides,
filter_shape,
filter_strides,
output_shape,
output_strides,
std::vector<int64_t>(kernel_strides_.begin(), kernel_strides_.end()),
std::vector<int64_t>(padding_lo_.begin(), padding_lo_.end()),
std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
groups_);
conv.run(encoder, input.data<void>(), filter.data<void>(), out.data<void>());
convert_vector<int64_t>(kernel_strides_),
convert_vector<int64_t>(padding_lo_),
convert_vector<int64_t>(padding_hi_),
convert_vector<int64_t>(kernel_dilation_));
conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
}
} // namespace mlx::core

View File

@@ -16,7 +16,6 @@ cuda_skip = {
# Convolutions NYI
"TestConv.test_1d_conv_with_2d",
"TestConv.test_asymmetric_padding",
"TestConv.test_basic_grad_shapes",
"TestConv.test_conv2d_unaligned_channels",
"TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad",

View File

@@ -4,7 +4,6 @@
#define _USE_MATH_DEFINES
#include <cmath>
#include <iostream>
#include <numeric>
#include "doctest/doctest.h"
@@ -3643,7 +3642,6 @@ TEST_CASE("test conv1d") {
{1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>());
}
@@ -3722,7 +3720,6 @@ TEST_CASE("test conv2d") {
auto expected =
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>());
}