mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix C++ conv tests
This commit is contained in:
@@ -29,19 +29,6 @@ using namespace cudnn_frontend;
|
|||||||
throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
|
throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
auto swapaxes(const array& in, int axis1, int axis2) {
|
|
||||||
std::vector<int> axes(in.ndim());
|
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
|
||||||
std::swap(axes[axis1], axes[axis2]);
|
|
||||||
std::vector<int64_t> shape(in.ndim());
|
|
||||||
std::vector<int64_t> strides(in.ndim());
|
|
||||||
for (size_t ax = 0; ax < axes.size(); ++ax) {
|
|
||||||
shape[ax] = in.shape()[axes[ax]];
|
|
||||||
strides[ax] = in.strides()[axes[ax]];
|
|
||||||
}
|
|
||||||
return std::make_tuple(shape, strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
class Convolution {
|
class Convolution {
|
||||||
public:
|
public:
|
||||||
Convolution(
|
Convolution(
|
||||||
@@ -56,8 +43,7 @@ class Convolution {
|
|||||||
const std::vector<int64_t>& stride,
|
const std::vector<int64_t>& stride,
|
||||||
const std::vector<int64_t>& padding_lo,
|
const std::vector<int64_t>& padding_lo,
|
||||||
const std::vector<int64_t>& padding_hi,
|
const std::vector<int64_t>& padding_hi,
|
||||||
const std::vector<int64_t>& dilation,
|
const std::vector<int64_t>& dilation)
|
||||||
int groups)
|
|
||||||
: handle_(device.cudnn_handle()) {
|
: handle_(device.cudnn_handle()) {
|
||||||
auto cudnn_type = dtype_to_cudnn_type(dtype);
|
auto cudnn_type = dtype_to_cudnn_type(dtype);
|
||||||
bool is_half = dtype == float16 || dtype == bfloat16;
|
bool is_half = dtype == float16 || dtype == bfloat16;
|
||||||
@@ -98,7 +84,7 @@ class Convolution {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
cu::CommandEncoder& encoder,
|
CommandEncoder& encoder,
|
||||||
const void* input,
|
const void* input,
|
||||||
const void* filter,
|
const void* filter,
|
||||||
void* output) {
|
void* output) {
|
||||||
@@ -160,37 +146,57 @@ class Convolution {
|
|||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
|
||||||
|
return std::vector<T>(vec.begin(), vec.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto nhwc_to_nchw(const array& in) {
|
||||||
|
auto shape = convert_vector<int64_t>(in.shape());
|
||||||
|
shape.insert(shape.begin() + 1, shape.back());
|
||||||
|
shape.erase(shape.end() - 1);
|
||||||
|
auto strides = convert_vector<int64_t>(in.strides());
|
||||||
|
strides.insert(strides.begin() + 1, strides.back());
|
||||||
|
strides.erase(strides.end() - 1);
|
||||||
|
return std::make_tuple(shape, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("Convolution::eval_gpu");
|
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||||
auto& s = stream();
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
|
||||||
|
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
const array& input = inputs[0];
|
const array& in = inputs[0];
|
||||||
const array& filter = inputs[1];
|
const array& wt = inputs[1];
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_input_array(wt);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
// cuDNN requires dims to be passed as NCHW.
|
// cuDNN requires dims to be passed as NCHW.
|
||||||
int ndim = input.ndim();
|
auto [input_shape, input_strides] = nhwc_to_nchw(in);
|
||||||
auto [input_shape, input_strides] = cu::swapaxes(input, 1, ndim - 1);
|
auto [filter_shape, filter_strides] = nhwc_to_nchw(wt);
|
||||||
auto [filter_shape, filter_strides] = cu::swapaxes(filter, 1, ndim - 1);
|
auto [output_shape, output_strides] = nhwc_to_nchw(out);
|
||||||
auto [output_shape, output_strides] = cu::swapaxes(out, 1, ndim - 1);
|
|
||||||
|
|
||||||
cu::Convolution conv(
|
cu::Convolution conv(
|
||||||
cu::device(s.device),
|
cu::device(s.device),
|
||||||
input.dtype(),
|
in.dtype(),
|
||||||
input_shape,
|
input_shape,
|
||||||
input_strides,
|
input_strides,
|
||||||
filter_shape,
|
filter_shape,
|
||||||
filter_strides,
|
filter_strides,
|
||||||
output_shape,
|
output_shape,
|
||||||
output_strides,
|
output_strides,
|
||||||
std::vector<int64_t>(kernel_strides_.begin(), kernel_strides_.end()),
|
convert_vector<int64_t>(kernel_strides_),
|
||||||
std::vector<int64_t>(padding_lo_.begin(), padding_lo_.end()),
|
convert_vector<int64_t>(padding_lo_),
|
||||||
std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
|
convert_vector<int64_t>(padding_hi_),
|
||||||
std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
|
convert_vector<int64_t>(kernel_dilation_));
|
||||||
groups_);
|
conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
|
||||||
conv.run(encoder, input.data<void>(), filter.data<void>(), out.data<void>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ cuda_skip = {
|
|||||||
# Convolutions NYI
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_asymmetric_padding",
|
"TestConv.test_asymmetric_padding",
|
||||||
"TestConv.test_basic_grad_shapes",
|
|
||||||
"TestConv.test_conv2d_unaligned_channels",
|
"TestConv.test_conv2d_unaligned_channels",
|
||||||
"TestConv.test_conv_1d_groups_flipped",
|
"TestConv.test_conv_1d_groups_flipped",
|
||||||
"TestConv.test_conv_general_flip_grad",
|
"TestConv.test_conv_general_flip_grad",
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
#define _USE_MATH_DEFINES
|
#define _USE_MATH_DEFINES
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iostream>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
@@ -3643,7 +3642,6 @@ TEST_CASE("test conv1d") {
|
|||||||
{1, 3, 4});
|
{1, 3, 4});
|
||||||
|
|
||||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||||
std::cout << out << std::endl;
|
|
||||||
CHECK(allclose(out, expected).item<bool>());
|
CHECK(allclose(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3722,7 +3720,6 @@ TEST_CASE("test conv2d") {
|
|||||||
auto expected =
|
auto expected =
|
||||||
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
|
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
|
||||||
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
||||||
std::cout << out << std::endl;
|
|
||||||
CHECK(allclose(out, expected).item<bool>());
|
CHECK(allclose(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user