Fix C++ conv tests

(mlx commit; mirror of https://github.com/ml-explore/mlx.git)
@@ -29,19 +29,6 @@ using namespace cudnn_frontend;
       throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
   }
 
-auto swapaxes(const array& in, int axis1, int axis2) {
-  std::vector<int> axes(in.ndim());
-  std::iota(axes.begin(), axes.end(), 0);
-  std::swap(axes[axis1], axes[axis2]);
-  std::vector<int64_t> shape(in.ndim());
-  std::vector<int64_t> strides(in.ndim());
-  for (size_t ax = 0; ax < axes.size(); ++ax) {
-    shape[ax] = in.shape()[axes[ax]];
-    strides[ax] = in.strides()[axes[ax]];
-  }
-  return std::make_tuple(shape, strides);
-}
-
 class Convolution {
  public:
   Convolution(
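For context, the removed cu::swapaxes produced a transposed view by permuting shape and stride metadata only; no element data moves. Below is a minimal standalone sketch of that logic (plain vectors stand in for mlx::core::array, and the sample shape is illustrative). Note that swapping axis 1 with the last axis turns an NHWC layout into (N, C, W, H), transposing the two spatial axes along the way, whereas the nhwc_to_nchw helper added later in this diff rotates the channel axis into place and keeps H before W; that difference is plausibly what this commit fixes.

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Standalone rendition of the removed cu::swapaxes: permute the shape and
// stride metadata of a strided view. The direct swap is equivalent to the
// removed iota-based permutation loop.
std::pair<std::vector<int64_t>, std::vector<int64_t>> swapaxes_meta(
    std::vector<int64_t> shape,
    std::vector<int64_t> strides,
    int axis1,
    int axis2) {
  std::swap(shape[axis1], shape[axis2]);
  std::swap(strides[axis1], strides[axis2]);
  return {std::move(shape), std::move(strides)};
}

int main() {
  // Contiguous NHWC tensor: N=2, H=4, W=5, C=3 (row-major strides).
  std::vector<int64_t> shape = {2, 4, 5, 3};
  std::vector<int64_t> strides = {60, 15, 3, 1};

  // swapaxes(in, 1, ndim - 1), as the old code path called it.
  auto [s, st] = swapaxes_meta(shape, strides, 1, 3);
  for (auto d : s) std::cout << d << ' ';   // 2 3 5 4  -> (N, C, W, H)
  std::cout << '\n';
  for (auto d : st) std::cout << d << ' ';  // 60 1 3 15
  std::cout << '\n';
}
```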
@@ -56,8 +43,7 @@ class Convolution {
       const std::vector<int64_t>& stride,
       const std::vector<int64_t>& padding_lo,
       const std::vector<int64_t>& padding_hi,
-      const std::vector<int64_t>& dilation,
-      int groups)
+      const std::vector<int64_t>& dilation)
       : handle_(device.cudnn_handle()) {
     auto cudnn_type = dtype_to_cudnn_type(dtype);
     bool is_half = dtype == float16 || dtype == bfloat16;
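The constructor also loses its groups argument here (the matching groups_ argument disappears from the call site in the eval_gpu hunk below). As a reminder of what that parameter controls, here is a small illustrative sketch, not MLX code: with g groups, each block of C_out / g output channels reads only its corresponding block of C_in / g input channels, and both channel counts must divide evenly by g.

```cpp
#include <cstdio>

// Illustrative helper (not from MLX): returns the first input channel
// visible to a given output channel of a grouped convolution.
int group_input_offset(int c_out, int C_in, int C_out, int groups) {
  int group = c_out / (C_out / groups);  // which group this output belongs to
  return group * (C_in / groups);        // start of that group's input block
}

int main() {
  const int C_in = 8, C_out = 4, groups = 2;
  for (int c = 0; c < C_out; ++c) {
    // Output channels 0-1 read input channels [0, 4); 2-3 read [4, 8).
    std::printf("out %d reads inputs [%d, %d)\n",
                c,
                group_input_offset(c, C_in, C_out, groups),
                group_input_offset(c, C_in, C_out, groups) + C_in / groups);
  }
}
```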
@@ -98,7 +84,7 @@ class Convolution {
   }
 
   void run(
-      cu::CommandEncoder& encoder,
+      CommandEncoder& encoder,
       const void* input,
       const void* filter,
       void* output) {
@@ -160,37 +146,57 @@ class Convolution {
 
 } // namespace cu
 
+namespace {
+
+template <typename T, typename U>
+inline std::vector<T> convert_vector(const std::vector<U>& vec) {
+  return std::vector<T>(vec.begin(), vec.end());
+}
+
+auto nhwc_to_nchw(const array& in) {
+  auto shape = convert_vector<int64_t>(in.shape());
+  shape.insert(shape.begin() + 1, shape.back());
+  shape.erase(shape.end() - 1);
+  auto strides = convert_vector<int64_t>(in.strides());
+  strides.insert(strides.begin() + 1, strides.back());
+  strides.erase(strides.end() - 1);
+  return std::make_tuple(shape, strides);
+}
+
+} // namespace
+
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Convolution::eval_gpu");
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
-
   assert(inputs.size() == 2);
-  const array& input = inputs[0];
-  const array& filter = inputs[1];
+  const array& in = inputs[0];
+  const array& wt = inputs[1];
   out.set_data(allocator::malloc(out.nbytes()));
 
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_input_array(wt);
+  encoder.set_output_array(out);
+
   // cuDNN requires dims to be passed as NCHW.
-  int ndim = input.ndim();
-  auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1);
-  auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1);
-  auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1);
+  auto [input_shape, input_strides] = nhwc_to_nchw(in);
+  auto [filter_shape, filter_strides] = nhwc_to_nchw(wt);
+  auto [output_shape, output_strides] = nhwc_to_nchw(out);
 
   cu::Convolution conv(
       cu::device(s.device),
-      input.dtype(),
+      in.dtype(),
       input_shape,
       input_strides,
       filter_shape,
       filter_strides,
       output_shape,
       output_strides,
-      std::vector<int64_t>(kernel_strides_.begin(), kernel_strides_.end()),
-      std::vector<int64_t>(padding_lo_.begin(), padding_lo_.end()),
-      std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
-      std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
-      groups_);
-  conv.run(encoder, input.data<void>(), filter.data<void>(), out.data<void>());
+      convert_vector<int64_t>(kernel_strides_),
+      convert_vector<int64_t>(padding_lo_),
+      convert_vector<int64_t>(padding_hi_),
+      convert_vector<int64_t>(kernel_dilation_));
+  conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
 }
 
 } // namespace mlx::core
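A standalone sketch of the nhwc_to_nchw rotation added above, runnable outside MLX (the sample shape and strides are illustrative). The insert/erase pair moves the trailing channel entry to position 1 in both vectors, so the helper describes the same untouched NHWC buffer in NCHW dim order, with the permuted strides telling cuDNN how to walk it:

```cpp
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

// Rotate the trailing channel axis into position 1 in both shape and
// strides. Only view metadata changes; no data is copied.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> nhwc_to_nchw(
    std::vector<int64_t> shape,
    std::vector<int64_t> strides) {
  shape.insert(shape.begin() + 1, shape.back());
  shape.erase(shape.end() - 1);
  strides.insert(strides.begin() + 1, strides.back());
  strides.erase(strides.end() - 1);
  return {shape, strides};
}

int main() {
  // Contiguous NHWC tensor: N=2, H=4, W=5, C=3.
  auto [s, st] = nhwc_to_nchw({2, 4, 5, 3}, {60, 15, 3, 1});
  for (auto d : s) std::cout << d << ' ';   // 2 3 4 5  -> (N, C, H, W)
  std::cout << '\n';
  for (auto d : st) std::cout << d << ' ';  // 60 1 15 3
  std::cout << '\n';
}
```

Unlike the removed swapaxes(in, 1, ndim - 1), which yields (N, C, W, H), this keeps the spatial axes in H-before-W order, matching what the "dims to be passed as NCHW" comment promises.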
@@ -16,7 +16,6 @@ cuda_skip = {
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_asymmetric_padding",
-    "TestConv.test_basic_grad_shapes",
     "TestConv.test_conv2d_unaligned_channels",
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
@@ -4,7 +4,6 @@
 #define _USE_MATH_DEFINES
 
 #include <cmath>
-#include <iostream>
 #include <numeric>
 
 #include "doctest/doctest.h"
@@ -3643,7 +3642,6 @@ TEST_CASE("test conv1d") {
       {1, 3, 4});
 
   auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
-  std::cout << out << std::endl;
   CHECK(allclose(out, expected).item<bool>());
 }
 
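For anyone verifying the expected tensors in these tests by hand, the output extent along each spatial axis follows the standard convolution sizing rule, out = floor((in + pad_lo + pad_hi - dilation * (kernel - 1) - 1) / stride) + 1. A small sketch, not MLX code (the sample numbers are illustrative):

```cpp
#include <cstdio>

// Standard output-size formula for a convolution along one spatial axis,
// using the usual flooring integer division.
int conv_out_size(
    int in_size, int kernel, int stride, int pad_lo, int pad_hi, int dilation) {
  int effective_kernel = dilation * (kernel - 1) + 1;
  return (in_size + pad_lo + pad_hi - effective_kernel) / stride + 1;
}

int main() {
  // A kernel spanning the whole input with no padding leaves one output
  // position per spatial axis, consistent with expected shapes like
  // {1, 1, 1, 4} where only the channel axis carries size.
  std::printf("%d\n", conv_out_size(4, 4, 1, 0, 0, 1));  // 1
  // Stride 2 halves the number of positions: (4 - 2) / 2 + 1 = 2.
  std::printf("%d\n", conv_out_size(4, 2, 2, 0, 0, 1));  // 2
}
```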
@@ -3722,7 +3720,6 @@ TEST_CASE("test conv2d") {
   auto expected =
       array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
   auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
-  std::cout << out << std::endl;
   CHECK(allclose(out, expected).item<bool>());
 }