mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
[CUDA] Backward convolution (#2431)
This commit is contained in:
parent
8b25ce62d5
commit
86c6a15571
@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
|||||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||||
|
int ndim = x.ndim();
|
||||||
|
if (axis1 < 0) {
|
||||||
|
axis1 += ndim;
|
||||||
|
}
|
||||||
|
if (axis2 < 0) {
|
||||||
|
axis2 += ndim;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape = x.shape();
|
||||||
|
std::swap(shape[axis1], shape[axis2]);
|
||||||
|
auto strides = x.strides();
|
||||||
|
std::swap(strides[axis1], strides[axis2]);
|
||||||
|
|
||||||
|
auto [data_size, row_contiguous, col_contiguous] =
|
||||||
|
check_contiguity(shape, strides);
|
||||||
|
bool contiguous = data_size == x.data_size();
|
||||||
|
|
||||||
|
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||||
|
out.copy_shared_buffer(
|
||||||
|
x,
|
||||||
|
std::move(strides),
|
||||||
|
{contiguous, row_contiguous, col_contiguous},
|
||||||
|
x.data_size());
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -196,6 +196,9 @@ void shared_buffer_reshape(
|
|||||||
const Strides& out_strides,
|
const Strides& out_strides,
|
||||||
array& out);
|
array& out);
|
||||||
|
|
||||||
|
// Like the swapaxes op but safe to call in eval_gpu.
|
||||||
|
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||||
vec.erase(std::next(vec.begin(), index));
|
vec.erase(std::next(vec.begin(), index));
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -25,25 +24,34 @@ namespace {
|
|||||||
// Not all engines support it so can not use this API now.
|
// Not all engines support it so can not use this API now.
|
||||||
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||||
|
|
||||||
|
// Alias for better readability.
|
||||||
|
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
|
||||||
|
#define CONV_BACKWARD_INPUT \
|
||||||
|
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
|
||||||
|
#define CONV_BACKWARD_WEIGHT \
|
||||||
|
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
||||||
|
|
||||||
struct ConvCacheKey {
|
struct ConvCacheKey {
|
||||||
int device_id;
|
int device_id;
|
||||||
cudnnBackendDescriptorType_t backend_type;
|
cudnnDataType_t cudnn_dtype;
|
||||||
cudnnDataType_t cudnn_type;
|
|
||||||
std::array<int, MAX_NDIM> input_shape;
|
std::array<int, MAX_NDIM> input_shape;
|
||||||
std::array<int, MAX_NDIM> filter_shape;
|
std::array<int, MAX_NDIM> weight_shape;
|
||||||
|
std::array<int, MAX_NDIM> stride;
|
||||||
std::array<int, MAX_NDIM> padding_lo;
|
std::array<int, MAX_NDIM> padding_lo;
|
||||||
std::array<int, MAX_NDIM> padding_hi;
|
std::array<int, MAX_NDIM> padding_hi;
|
||||||
std::array<int, MAX_NDIM> stride;
|
|
||||||
std::array<int, MAX_NDIM> dilation;
|
std::array<int, MAX_NDIM> dilation;
|
||||||
int groups;
|
int groups;
|
||||||
|
bool flip;
|
||||||
uint8_t input_alignment;
|
uint8_t input_alignment;
|
||||||
uint8_t filter_alignment;
|
uint8_t weight_alignment;
|
||||||
uint8_t output_alignment;
|
uint8_t output_alignment;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto& conv_cache() {
|
auto& conv_cache() {
|
||||||
static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
|
static LRUBytesKeyCache<
|
||||||
/* capacity */ 128);
|
ConvCacheKey,
|
||||||
|
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
|
||||||
|
cache(/* capacity */ 128);
|
||||||
return cache;
|
return cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,17 +170,17 @@ cudnn_frontend::EngineConfigList get_engine_configs(
|
|||||||
bool execute_plan(
|
bool execute_plan(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
cudnn_frontend::ExecutionPlan& plan,
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
const array& in,
|
array& x,
|
||||||
const array& wt,
|
array& w,
|
||||||
array& out) {
|
array& y) {
|
||||||
int workspace_size = plan.getWorkspaceSize();
|
int workspace_size = plan.getWorkspaceSize();
|
||||||
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
||||||
|
|
||||||
int64_t uids[3] = {'x', 'w', 'y'};
|
int64_t uids[3] = {'x', 'w', 'y'};
|
||||||
void* data_ptrs[3] = {
|
void* data_ptrs[3] = {
|
||||||
const_cast<void*>(in.data<void>()),
|
x.data<void>(),
|
||||||
const_cast<void*>(wt.data<void>()),
|
w.data<void>(),
|
||||||
out.data<void>(),
|
y.data<void>(),
|
||||||
};
|
};
|
||||||
|
|
||||||
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
||||||
@ -212,46 +220,154 @@ bool execute_plan(
|
|||||||
|
|
||||||
bool try_engines(
|
bool try_engines(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
cudnn_frontend::EngineConfigList& configs,
|
|
||||||
const ConvCacheKey& cache_key,
|
const ConvCacheKey& cache_key,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
cudnn_frontend::EngineConfigList& configs,
|
||||||
const std::string& op_graph_tag,
|
const std::string& op_graph_tag,
|
||||||
const array& in,
|
array& x,
|
||||||
const array& wt,
|
array& w,
|
||||||
array& out) {
|
array& y) {
|
||||||
for (auto& config : configs) {
|
for (auto& config : configs) {
|
||||||
try {
|
try {
|
||||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||||
.setHandle(encoder.device().cudnn_handle())
|
.setHandle(encoder.device().cudnn_handle())
|
||||||
.setEngineConfig(config, op_graph_tag)
|
.setEngineConfig(config, op_graph_tag)
|
||||||
.build();
|
.build();
|
||||||
if (execute_plan(encoder, plan, in, wt, out)) {
|
if (execute_plan(encoder, plan, x, w, y)) {
|
||||||
conv_cache().emplace(cache_key, std::move(plan));
|
conv_cache().emplace(
|
||||||
|
cache_key, std::make_pair(backend_type, std::move(plan)));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} catch (cudnn_frontend::cudnnException&) {
|
} catch (cudnn_frontend::cudnnException& error) {
|
||||||
|
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||||
|
throw;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
auto get_conv_op_settings(
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
array& x,
|
||||||
|
array& w,
|
||||||
|
array& y,
|
||||||
|
const std::vector<int>& kernel_strides,
|
||||||
|
const std::vector<int>& padding_lo_,
|
||||||
|
const std::vector<int>& padding_hi_,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation) {
|
||||||
|
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
||||||
|
auto padding_hi = convert_vector<int64_t>(padding_hi_);
|
||||||
|
|
||||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||||
nvtx3::scoped_range r("Convolution::eval_gpu");
|
for (int i = 0; i < padding_lo.size(); ++i) {
|
||||||
if (out.size() == 0) {
|
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
|
||||||
return;
|
padding_lo[i] = wt_size - padding_lo[i] - 1;
|
||||||
|
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
|
||||||
|
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
|
||||||
|
padding_hi[i] = out_size - in_size + padding_hi[i];
|
||||||
|
}
|
||||||
|
return std::make_tuple(
|
||||||
|
convert_vector<int64_t>(input_dilation),
|
||||||
|
std::move(padding_lo),
|
||||||
|
std::move(padding_hi),
|
||||||
|
convert_vector<int64_t>(kernel_dilation));
|
||||||
|
|
||||||
|
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||||
|
padding_hi = padding_lo;
|
||||||
|
return std::make_tuple(
|
||||||
|
convert_vector<int64_t>(kernel_dilation),
|
||||||
|
std::move(padding_lo),
|
||||||
|
std::move(padding_hi),
|
||||||
|
convert_vector<int64_t>(kernel_strides));
|
||||||
|
|
||||||
|
} else {
|
||||||
|
return std::make_tuple(
|
||||||
|
convert_vector<int64_t>(kernel_strides),
|
||||||
|
std::move(padding_lo),
|
||||||
|
std::move(padding_hi),
|
||||||
|
convert_vector<int64_t>(kernel_dilation));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
Dtype dtype,
|
||||||
|
array& x,
|
||||||
|
array& w,
|
||||||
|
array& y,
|
||||||
|
const std::vector<int64_t>& stride,
|
||||||
|
const std::vector<int64_t>& padding_lo,
|
||||||
|
const std::vector<int64_t>& padding_hi,
|
||||||
|
const std::vector<int64_t>& dilation) {
|
||||||
|
try {
|
||||||
|
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
||||||
|
? CUDNN_DATA_FLOAT
|
||||||
|
: dtype_to_cudnn_type(dtype);
|
||||||
|
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
||||||
|
.setDataType(compute_dtype)
|
||||||
|
.setMathMode(CUDNN_CROSS_CORRELATION)
|
||||||
|
.setNDims(stride.size())
|
||||||
|
.setStrides(stride.size(), stride.data())
|
||||||
|
.setPrePadding(padding_lo.size(), padding_lo.data())
|
||||||
|
.setPostPadding(padding_hi.size(), padding_hi.data())
|
||||||
|
.setDilation(dilation.size(), dilation.data())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||||
|
.setxDesc(build_tensor('x', x))
|
||||||
|
.setwDesc(build_tensor('w', w))
|
||||||
|
.setyDesc(build_tensor('y', y))
|
||||||
|
.setcDesc(conv_desc)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
||||||
|
return cudnn_frontend::OperationGraphBuilder()
|
||||||
|
.setHandle(encoder.device().cudnn_handle())
|
||||||
|
.setOperationGraph(ops.size(), ops.data())
|
||||||
|
.build();
|
||||||
|
} catch (cudnn_frontend::cudnnException& error) {
|
||||||
|
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do necessary transposes and copies to prepare the inputs and outputs for
|
||||||
|
// building the cuDNN conv op. It is safe to be called multiple times in one
|
||||||
|
// eval_gpu, with cost of possible redundant copies.
|
||||||
|
std::tuple<array, array, array> prepare_args(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
array in,
|
||||||
|
array wt,
|
||||||
|
array out,
|
||||||
|
Stream s) {
|
||||||
|
// Transpose the args depending on the backend type.
|
||||||
|
// TODO: Handle groups.
|
||||||
|
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||||
|
wt = swapaxes_in_eval(wt, 0, -1);
|
||||||
|
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||||
|
in = swapaxes_in_eval(in, 0, -1);
|
||||||
|
wt = swapaxes_in_eval(wt, 0, -1);
|
||||||
|
// Create a contiguous array that shares the data with |out|, but with dim
|
||||||
|
// C_in and C_out swapped.
|
||||||
|
Shape shape(out.shape());
|
||||||
|
std::swap(shape.front(), shape.back());
|
||||||
|
Strides strides(shape.size(), 1);
|
||||||
|
for (int i = shape.size() - 2; i >= 0; --i) {
|
||||||
|
strides[i] = shape[i + 1] * strides[i + 1];
|
||||||
|
}
|
||||||
|
array intermediate(std::move(shape), out.dtype(), nullptr, {});
|
||||||
|
intermediate.copy_shared_buffer(
|
||||||
|
out, std::move(strides), {true, true, false}, out.data_size());
|
||||||
|
out = intermediate;
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(inputs.size() == 2);
|
|
||||||
array in = inputs[0];
|
|
||||||
array wt = inputs[1];
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
|
|
||||||
auto& s = stream();
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
|
||||||
|
|
||||||
// cuDNN requires contiguous input.
|
// cuDNN requires contiguous input.
|
||||||
// TODO: Handle NCHW format specially.
|
|
||||||
if (!in.flags().row_contiguous) {
|
if (!in.flags().row_contiguous) {
|
||||||
in = contiguous_copy_gpu(in, s);
|
in = contiguous_copy_gpu(in, s);
|
||||||
encoder.add_temporary(in);
|
encoder.add_temporary(in);
|
||||||
@ -261,80 +377,170 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.add_temporary(wt);
|
encoder.add_temporary(wt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return {std::move(in), std::move(wt), std::move(out)};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the x/w/y args from the in/wt/out args depending on backend type.
|
||||||
|
inline std::tuple<array&, array&, array&> dispatch_args(
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
array& in,
|
||||||
|
array& wt,
|
||||||
|
array& out) {
|
||||||
|
switch (backend_type) {
|
||||||
|
case CONV_BACKWARD_INPUT:
|
||||||
|
return {out, wt, in};
|
||||||
|
case CONV_BACKWARD_WEIGHT:
|
||||||
|
return {in, out, wt};
|
||||||
|
default:
|
||||||
|
return {in, wt, out};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register inputs and outputs before actually running conv op. Can only be
|
||||||
|
// called once per eval_gpu.
|
||||||
|
void register_args(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
array& in,
|
||||||
|
array& wt,
|
||||||
|
array& intermediate_out,
|
||||||
|
array& final_out) {
|
||||||
encoder.set_input_array(in);
|
encoder.set_input_array(in);
|
||||||
encoder.set_input_array(wt);
|
encoder.set_input_array(wt);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(final_out);
|
||||||
|
|
||||||
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
|
if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||||
auto cudnn_type = dtype_to_cudnn_type(in.dtype());
|
// Turn |out| into a strided array, which will have C_in and C_out swapped
|
||||||
|
// in vjp and the final |grad_weight| will then be contiguous.
|
||||||
|
Strides strides = intermediate_out.strides();
|
||||||
|
std::swap(strides.front(), strides.back());
|
||||||
|
final_out.copy_shared_buffer(
|
||||||
|
intermediate_out,
|
||||||
|
std::move(strides),
|
||||||
|
{false, false, false},
|
||||||
|
intermediate_out.data_size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||||
|
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||||
|
if (out_.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
array in = inputs[0];
|
||||||
|
array wt = inputs[1];
|
||||||
|
array out = out_;
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
Dtype dtype = out.dtype();
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
// Search cache.
|
// Search cache.
|
||||||
ConvCacheKey cache_key{
|
ConvCacheKey cache_key{
|
||||||
encoder.device().cuda_device(),
|
encoder.device().cuda_device(),
|
||||||
backend_type,
|
dtype_to_cudnn_type(dtype),
|
||||||
cudnn_type,
|
|
||||||
fixed_vector(in.shape()),
|
fixed_vector(in.shape()),
|
||||||
fixed_vector(wt.shape()),
|
fixed_vector(wt.shape()),
|
||||||
|
fixed_vector(kernel_strides_),
|
||||||
fixed_vector(padding_lo_),
|
fixed_vector(padding_lo_),
|
||||||
fixed_vector(padding_hi_),
|
fixed_vector(padding_hi_),
|
||||||
fixed_vector(kernel_strides_),
|
|
||||||
fixed_vector(kernel_dilation_),
|
fixed_vector(kernel_dilation_),
|
||||||
groups_,
|
groups_,
|
||||||
|
flip_,
|
||||||
get_alignment(in),
|
get_alignment(in),
|
||||||
get_alignment(wt),
|
get_alignment(wt),
|
||||||
get_alignment(out)};
|
get_alignment(out)};
|
||||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||||
if (!execute_plan(encoder, it->second, in, wt, out)) {
|
auto& [backend_type, plan] = it->second;
|
||||||
throw std::runtime_error("Cached convolution plan failed to execute.");
|
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
|
||||||
|
register_args(encoder, backend_type, in, wt, out, out_);
|
||||||
|
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||||
|
if (!execute_plan(encoder, plan, x, w, y)) {
|
||||||
|
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build operation graph.
|
// There is no reliable way to deduce the proper cuDNN backend for the
|
||||||
auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
|
// convolution, so we make a best guess and then try.
|
||||||
? CUDNN_DATA_FLOAT
|
std::vector<cudnnBackendDescriptorType_t> try_backends;
|
||||||
: cudnn_type;
|
if (flip_) {
|
||||||
|
// When weight is flipped, we assume it is backward input convolution.
|
||||||
|
try_backends.push_back(CONV_BACKWARD_INPUT);
|
||||||
|
} else {
|
||||||
|
// Otherwise it could be backward weight convolution or forward convolution,
|
||||||
|
// mathematically there is no difference so we have to use heuristics.
|
||||||
|
// Empirically backward convolutions have large kernel dimensions, and
|
||||||
|
// usually have |in| and |wt| transposed.
|
||||||
|
if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
|
||||||
|
wt.shape(2) > out.shape(2)) {
|
||||||
|
try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
|
||||||
|
} else {
|
||||||
|
try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto stride = convert_vector<int64_t>(kernel_strides_);
|
// Try to build op graph.
|
||||||
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
cudnnBackendDescriptorType_t backend_type;
|
||||||
auto padding_hi = convert_vector<int64_t>(padding_hi_);
|
std::optional<cudnn_frontend::OperationGraph> op_graph;
|
||||||
auto dilation = convert_vector<int64_t>(kernel_dilation_);
|
for (auto try_backend : try_backends) {
|
||||||
|
auto [in_copy, wt_copy, out_copy] =
|
||||||
|
prepare_args(encoder, try_backend, in, wt, out, s);
|
||||||
|
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
|
||||||
|
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
|
||||||
|
try_backend,
|
||||||
|
x,
|
||||||
|
w,
|
||||||
|
y,
|
||||||
|
kernel_strides_,
|
||||||
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
|
kernel_dilation_,
|
||||||
|
input_dilation_);
|
||||||
|
op_graph = build_op_graph(
|
||||||
|
encoder,
|
||||||
|
try_backend,
|
||||||
|
dtype,
|
||||||
|
x,
|
||||||
|
w,
|
||||||
|
y,
|
||||||
|
stride,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
dilation);
|
||||||
|
if (op_graph) {
|
||||||
|
backend_type = try_backend;
|
||||||
|
in = std::move(in_copy);
|
||||||
|
wt = std::move(wt_copy);
|
||||||
|
out = std::move(out_copy);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!op_graph) {
|
||||||
|
throw std::runtime_error("[conv] Can not build op graph.");
|
||||||
|
}
|
||||||
|
|
||||||
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
// Get ready to execute the graph.
|
||||||
.setDataType(compute_data_type)
|
register_args(encoder, backend_type, in, wt, out, out_);
|
||||||
.setMathMode(CUDNN_CROSS_CORRELATION)
|
|
||||||
.setNDims(stride.size())
|
|
||||||
.setStrides(stride.size(), stride.data())
|
|
||||||
.setPrePadding(padding_lo.size(), padding_lo.data())
|
|
||||||
.setPostPadding(padding_hi.size(), padding_hi.data())
|
|
||||||
.setDilation(dilation.size(), dilation.data())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
|
||||||
.setxDesc(build_tensor('x', in))
|
|
||||||
.setwDesc(build_tensor('w', wt))
|
|
||||||
.setyDesc(build_tensor('y', out))
|
|
||||||
.setcDesc(conv_desc)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
|
||||||
auto op_graph = cudnn_frontend::OperationGraphBuilder()
|
|
||||||
.setHandle(encoder.device().cudnn_handle())
|
|
||||||
.setOperationGraph(ops.size(), ops.data())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
// Try to run plans based on heuristics.
|
// Try to run plans based on heuristics.
|
||||||
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
auto configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||||
auto op_graph_tag = op_graph.getTag();
|
auto tag = op_graph->getTag();
|
||||||
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||||
|
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Then try fallback plans.
|
// Then try fallback plans.
|
||||||
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||||
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
throw std::runtime_error("Unable to find an engine for convolution.");
|
throw std::runtime_error("[conv] Unable to find a working engine.");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
#include <cub/device/device_segmented_sort.cuh>
|
#include <cub/device/device_segmented_sort.cuh>
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -27,29 +26,6 @@ struct ModOp {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// We can not use any op in eval, make an utility.
|
|
||||||
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
|
|
||||||
std::vector<int> axes(in.ndim());
|
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
|
||||||
std::swap(axes[axis1], axes[axis2]);
|
|
||||||
// TODO: Share the code with Transpose::eval.
|
|
||||||
Shape shape(axes.size());
|
|
||||||
Strides strides(in.ndim());
|
|
||||||
for (size_t ax = 0; ax < axes.size(); ++ax) {
|
|
||||||
shape[ax] = in.shape()[axes[ax]];
|
|
||||||
strides[ax] = in.strides()[axes[ax]];
|
|
||||||
}
|
|
||||||
auto flags = in.flags();
|
|
||||||
if (flags.contiguous) {
|
|
||||||
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
|
|
||||||
flags.row_contiguous = row_contiguous;
|
|
||||||
flags.col_contiguous = col_contiguous;
|
|
||||||
}
|
|
||||||
array out(shape, in.dtype(), nullptr, {});
|
|
||||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct OffsetTransform {
|
struct OffsetTransform {
|
||||||
int nsort;
|
int nsort;
|
||||||
|
|
||||||
|
@ -18,21 +18,12 @@ cuda_skip = {
|
|||||||
"TestConv.test_conv_1d_groups_flipped",
|
"TestConv.test_conv_1d_groups_flipped",
|
||||||
"TestConv.test_conv_general_flip_grad",
|
"TestConv.test_conv_general_flip_grad",
|
||||||
"TestConv.test_conv_groups_grad",
|
"TestConv.test_conv_groups_grad",
|
||||||
"TestConv.test_torch_conv_1D_grad",
|
|
||||||
"TestConv.test_torch_conv_2D",
|
"TestConv.test_torch_conv_2D",
|
||||||
"TestConv.test_torch_conv_2D_grad",
|
|
||||||
"TestConv.test_torch_conv_3D_grad",
|
|
||||||
"TestConv.test_torch_conv_depthwise",
|
"TestConv.test_torch_conv_depthwise",
|
||||||
"TestConv.test_torch_conv_general",
|
"TestConv.test_torch_conv_general",
|
||||||
"TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
|
|
||||||
"TestConvTranspose.test_torch_conv_transpose_1D",
|
|
||||||
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
|
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_2D",
|
|
||||||
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
|
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
|
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
|
||||||
# FFTs NYI
|
# FFTs NYI
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
|
Loading…
Reference in New Issue
Block a user