Fix C++ conv tests

Author: Cheng
Date:   2025-07-18 01:29:38 -07:00
Commit: 180ec0d3a5 (parent cea3af6622)
3 changed files with 38 additions and 36 deletions

View File

@@ -29,19 +29,6 @@ using namespace cudnn_frontend;
     throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
   }
 
-auto swapaxes(const array& in, int axis1, int axis2) {
-  std::vector<int> axes(in.ndim());
-  std::iota(axes.begin(), axes.end(), 0);
-  std::swap(axes[axis1], axes[axis2]);
-  std::vector<int64_t> shape(in.ndim());
-  std::vector<int64_t> strides(in.ndim());
-  for (size_t ax = 0; ax < axes.size(); ++ax) {
-    shape[ax] = in.shape()[axes[ax]];
-    strides[ax] = in.strides()[axes[ax]];
-  }
-  return std::make_tuple(shape, strides);
-}
-
 class Convolution {
  public:
   Convolution(
@@ -56,8 +43,7 @@ class Convolution {
const std::vector<int64_t>& stride, const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding_lo, const std::vector<int64_t>& padding_lo,
const std::vector<int64_t>& padding_hi, const std::vector<int64_t>& padding_hi,
const std::vector<int64_t>& dilation, const std::vector<int64_t>& dilation)
int groups)
: handle_(device.cudnn_handle()) { : handle_(device.cudnn_handle()) {
auto cudnn_type = dtype_to_cudnn_type(dtype); auto cudnn_type = dtype_to_cudnn_type(dtype);
bool is_half = dtype == float16 || dtype == bfloat16; bool is_half = dtype == float16 || dtype == bfloat16;
@@ -98,7 +84,7 @@ class Convolution {
   }
 
   void run(
-      cu::CommandEncoder& encoder,
+      CommandEncoder& encoder,
       const void* input,
       const void* filter,
       void* output) {
@@ -160,37 +146,57 @@ class Convolution {
 } // namespace cu
 
+namespace {
+
+template <typename T, typename U>
+inline std::vector<T> convert_vector(const std::vector<U>& vec) {
+  return std::vector<T>(vec.begin(), vec.end());
+}
+
+auto nhwc_to_nchw(const array& in) {
+  auto shape = convert_vector<int64_t>(in.shape());
+  shape.insert(shape.begin() + 1, shape.back());
+  shape.erase(shape.end() - 1);
+  auto strides = convert_vector<int64_t>(in.strides());
+  strides.insert(strides.begin() + 1, strides.back());
+  strides.erase(strides.end() - 1);
+  return std::make_tuple(shape, strides);
+}
+
+} // namespace
+
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Convolution::eval_gpu");
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
   assert(inputs.size() == 2);
-  const array& input = inputs[0];
-  const array& filter = inputs[1];
+  const array& in = inputs[0];
+  const array& wt = inputs[1];
   out.set_data(allocator::malloc(out.nbytes()));
 
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_input_array(wt);
+  encoder.set_output_array(out);
+
   // cuDNN requires dims to be passed as NCHW.
-  int ndim = input.ndim();
-  auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1);
-  auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1);
-  auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1);
+  auto [input_shape, input_strides] = nhwc_to_nchw(in);
+  auto [filter_shape, filter_strides] = nhwc_to_nchw(wt);
+  auto [output_shape, output_strides] = nhwc_to_nchw(out);
 
   cu::Convolution conv(
       cu::device(s.device),
-      input.dtype(),
+      in.dtype(),
       input_shape,
       input_strides,
       filter_shape,
       filter_strides,
       output_shape,
       output_strides,
-      std::vector<int64_t>(kernel_strides_.begin(), kernel_strides_.end()),
-      std::vector<int64_t>(padding_lo_.begin(), padding_lo_.end()),
-      std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
-      std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
-      groups_);
-  conv.run(encoder, input.data<void>(), filter.data<void>(), out.data<void>());
+      convert_vector<int64_t>(kernel_strides_),
+      convert_vector<int64_t>(padding_lo_),
+      convert_vector<int64_t>(padding_hi_),
+      convert_vector<int64_t>(kernel_dilation_));
+  conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
 }
 
 } // namespace mlx::core
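
Note on the new helper: nhwc_to_nchw only permutes metadata, never data. Moving the trailing channel entry of both the shape and strides vectors to position 1 re-describes the same NHWC buffer in the NCHW order cuDNN expects, so no transpose copy is needed. A minimal standalone sketch of the trick, using plain std::vector in place of mlx arrays (the name channels_last_to_first is hypothetical):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Move the trailing (channel) axis to position 1. Applying this to both
    // shape and strides re-describes an NHWC buffer as NCHW without copying.
    std::vector<int64_t> channels_last_to_first(std::vector<int64_t> dims) {
      dims.insert(dims.begin() + 1, dims.back());
      dims.erase(dims.end() - 1);
      return dims;
    }

    int main() {
      // Contiguous NHWC tensor: N=2, H=4, W=4, C=3 (strides in elements).
      std::vector<int64_t> shape = {2, 4, 4, 3};
      std::vector<int64_t> strides = {48, 12, 3, 1};
      for (auto d : channels_last_to_first(shape))
        std::cout << d << ' '; // prints: 2 3 4 4
      std::cout << '\n';
      for (auto s : channels_last_to_first(strides))
        std::cout << s << ' '; // prints: 48 1 12 3
      std::cout << '\n';
    }

The NCHW view {2, 3, 4, 4} with strides {48, 1, 12, 3} addresses exactly the same elements as the original NHWC layout: index (n, c, h, w) maps to n*48 + c*1 + h*12 + w*3, the same offset the NHWC indexing produces for (n, h, w, c).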

View File

@@ -16,7 +16,6 @@ cuda_skip = {
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_asymmetric_padding",
-    "TestConv.test_basic_grad_shapes",
     "TestConv.test_conv2d_unaligned_channels",
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",

View File

@@ -4,7 +4,6 @@
 #define _USE_MATH_DEFINES
 #include <cmath>
-#include <iostream>
 #include <numeric>
 
 #include "doctest/doctest.h"
@@ -3643,7 +3642,6 @@ TEST_CASE("test conv1d") {
{1, 3, 4}); {1, 3, 4});
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups); auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>()); CHECK(allclose(out, expected).item<bool>());
} }
@@ -3722,7 +3720,6 @@ TEST_CASE("test conv2d") {
auto expected = auto expected =
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4}); array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups); auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
std::cout << out << std::endl;
CHECK(allclose(out, expected).item<bool>()); CHECK(allclose(out, expected).item<bool>());
} }