@@ -16,7 +16,6 @@
 #include <nvtx3/nvtx3.hpp>
 
 #include <cassert>
-#include <numeric>
 
 namespace mlx::core {
@@ -25,25 +24,34 @@ namespace {
 
 // Not all engines support it, so we can not use this API for now.
 #define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
 
+// Aliases for better readability.
+#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
+#define CONV_BACKWARD_INPUT \
+  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
+#define CONV_BACKWARD_WEIGHT \
+  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
+
 struct ConvCacheKey {
   int device_id;
-  cudnnBackendDescriptorType_t backend_type;
-  cudnnDataType_t cudnn_type;
+  cudnnDataType_t cudnn_dtype;
   std::array<int, MAX_NDIM> input_shape;
-  std::array<int, MAX_NDIM> filter_shape;
+  std::array<int, MAX_NDIM> weight_shape;
+  std::array<int, MAX_NDIM> stride;
   std::array<int, MAX_NDIM> padding_lo;
   std::array<int, MAX_NDIM> padding_hi;
-  std::array<int, MAX_NDIM> stride;
   std::array<int, MAX_NDIM> dilation;
   int groups;
   bool flip;
   uint8_t input_alignment;
-  uint8_t filter_alignment;
+  uint8_t weight_alignment;
   uint8_t output_alignment;
 };
 
 auto& conv_cache() {
-  static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
-      /* capacity */ 128);
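+  // The value caches the backend type together with the compiled plan:
+  // replaying a plan requires redoing the same backend-specific argument
+  // preparation, and the backend type is no longer part of the key.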
+  static LRUBytesKeyCache<
+      ConvCacheKey,
+      std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
+      cache(/* capacity */ 128);
   return cache;
 }
@@ -162,17 +170,17 @@ cudnn_frontend::EngineConfigList get_engine_configs(
 
 bool execute_plan(
     cu::CommandEncoder& encoder,
    cudnn_frontend::ExecutionPlan& plan,
-    const array& in,
-    const array& wt,
-    array& out) {
+    array& x,
+    array& w,
+    array& y) {
   int workspace_size = plan.getWorkspaceSize();
   array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
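 
+  // The ids below are the char codes 'x', 'w' and 'y'; they must match the
+  // uids that build_tensor() assigned to the corresponding descriptors.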
   int64_t uids[3] = {'x', 'w', 'y'};
   void* data_ptrs[3] = {
-      const_cast<void*>(in.data<void>()),
-      const_cast<void*>(wt.data<void>()),
-      out.data<void>(),
+      x.data<void>(),
+      w.data<void>(),
+      y.data<void>(),
   };
 
   auto variantPack = cudnn_frontend::VariantPackBuilder()
@@ -212,46 +220,154 @@ bool execute_plan(
 
 bool try_engines(
     cu::CommandEncoder& encoder,
-    cudnn_frontend::EngineConfigList& configs,
     const ConvCacheKey& cache_key,
+    cudnnBackendDescriptorType_t backend_type,
+    cudnn_frontend::EngineConfigList& configs,
     const std::string& op_graph_tag,
-    const array& in,
-    const array& wt,
-    array& out) {
+    array& x,
+    array& w,
+    array& y) {
   for (auto& config : configs) {
     try {
       auto plan = cudnn_frontend::ExecutionPlanBuilder()
                       .setHandle(encoder.device().cudnn_handle())
                       .setEngineConfig(config, op_graph_tag)
                       .build();
-      if (execute_plan(encoder, plan, in, wt, out)) {
-        conv_cache().emplace(cache_key, std::move(plan));
+      if (execute_plan(encoder, plan, x, w, y)) {
+        conv_cache().emplace(
+            cache_key, std::make_pair(backend_type, std::move(plan)));
         return true;
       }
-    } catch (cudnn_frontend::cudnnException&) {
+    } catch (cudnn_frontend::cudnnException& error) {
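+      // An engine that simply does not support this config is expected and
+      // we move on to the next one; any other cuDNN error should propagate.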
+      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
+        throw;
+      }
     }
   }
   return false;
 }
 
-} // namespace
-
+auto get_conv_op_settings(
+    cudnnBackendDescriptorType_t backend_type,
+    array& x,
+    array& w,
+    array& y,
+    const std::vector<int>& kernel_strides,
+    const std::vector<int>& padding_lo_,
+    const std::vector<int>& padding_hi_,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation) {
+  auto padding_lo = convert_vector<int64_t>(padding_lo_);
+  auto padding_hi = convert_vector<int64_t>(padding_hi_);
+
-void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("Convolution::eval_gpu");
-  if (out.size() == 0) {
-    return;
-  }
-
+  if (backend_type == CONV_BACKWARD_INPUT) {
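+    // These are the usual transposed-convolution padding identities. A 1-D
+    // sketch (our own example): kernel size 3, dilation 1 and padding_lo 0
+    // give wt_size = 1 + 1 * (3 - 1) = 3 and a new pre-padding of
+    // 3 - 0 - 1 = 2; when in_size == out_size the post-padding is unchanged.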
+    for (int i = 0; i < padding_lo.size(); ++i) {
+      int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
+      padding_lo[i] = wt_size - padding_lo[i] - 1;
+      int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
+      int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
+      padding_hi[i] = out_size - in_size + padding_hi[i];
+    }
+    return std::make_tuple(
+        convert_vector<int64_t>(input_dilation),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_dilation));
+
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
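+    // Note the swap in the returned tuple below: the kernel dilation acts as
+    // the stride and the forward stride acts as the dilation. Mirroring
+    // padding_lo into padding_hi should keep the computed gradient at
+    // exactly the weight's spatial shape.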
+    padding_hi = padding_lo;
+    return std::make_tuple(
+        convert_vector<int64_t>(kernel_dilation),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_strides));
+
+  } else {
+    return std::make_tuple(
+        convert_vector<int64_t>(kernel_strides),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_dilation));
+  }
+}
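+
+// Returns nullopt when cuDNN rejects the graph with BAD_PARAM, i.e. when the
+// guessed backend type can not express this convolution; the caller then
+// falls back to the next candidate backend.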
+std::optional<cudnn_frontend::OperationGraph> build_op_graph(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    array& x,
+    array& w,
+    array& y,
+    const std::vector<int64_t>& stride,
+    const std::vector<int64_t>& padding_lo,
+    const std::vector<int64_t>& padding_hi,
+    const std::vector<int64_t>& dilation) {
+  try {
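+    // Half-precision convs compute in fp32; other dtypes compute in their
+    // own type.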
+    auto compute_dtype = (dtype == float16 || dtype == bfloat16)
+        ? CUDNN_DATA_FLOAT
+        : dtype_to_cudnn_type(dtype);
+    auto conv_desc = cudnn_frontend::ConvDescBuilder()
+                         .setDataType(compute_dtype)
+                         .setMathMode(CUDNN_CROSS_CORRELATION)
+                         .setNDims(stride.size())
+                         .setStrides(stride.size(), stride.data())
+                         .setPrePadding(padding_lo.size(), padding_lo.data())
+                         .setPostPadding(padding_hi.size(), padding_hi.data())
+                         .setDilation(dilation.size(), dilation.data())
+                         .build();
+
+    auto op = cudnn_frontend::OperationBuilder(backend_type)
+                  .setxDesc(build_tensor('x', x))
+                  .setwDesc(build_tensor('w', w))
+                  .setyDesc(build_tensor('y', y))
+                  .setcDesc(conv_desc)
+                  .build();
+
+    std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
+    return cudnn_frontend::OperationGraphBuilder()
+               .setHandle(encoder.device().cudnn_handle())
+               .setOperationGraph(ops.size(), ops.data())
+               .build();
+  } catch (cudnn_frontend::cudnnException& error) {
+    if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
+      throw;
+    }
+    return std::nullopt;
+  }
+}
+
+// Do the necessary transposes and copies to prepare the inputs and outputs
+// for building the cuDNN conv op. It is safe to call this multiple times in
+// one eval_gpu, at the cost of possibly redundant copies.
+std::tuple<array, array, array> prepare_args(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    array in,
+    array wt,
+    array out,
+    Stream s) {
+  // Transpose the args depending on the backend type.
+  // TODO: Handle groups.
+  if (backend_type == CONV_BACKWARD_INPUT) {
+    wt = swapaxes_in_eval(wt, 0, -1);
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
+    in = swapaxes_in_eval(in, 0, -1);
+    wt = swapaxes_in_eval(wt, 0, -1);
+    // Create a contiguous array that shares the data with |out|, but with
+    // dims C_in and C_out swapped.
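+    // A sketch of the trick: an |out| of shape (O, ..., I) is aliased below
+    // as an (I, ..., O) row-contiguous array over the same buffer, which is
+    // the layout that register_args() later swaps back for the final
+    // gradient.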
+    Shape shape(out.shape());
+    std::swap(shape.front(), shape.back());
+    Strides strides(shape.size(), 1);
+    for (int i = shape.size() - 2; i >= 0; --i) {
+      strides[i] = shape[i + 1] * strides[i + 1];
+    }
+    array intermediate(std::move(shape), out.dtype(), nullptr, {});
+    intermediate.copy_shared_buffer(
+        out, std::move(strides), {true, true, false}, out.data_size());
+    out = intermediate;
+  }
 
-  assert(inputs.size() == 2);
-  array in = inputs[0];
-  array wt = inputs[1];
-  out.set_data(allocator::malloc(out.nbytes()));
-
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
-
   // cuDNN requires contiguous input.
   // TODO: Handle NCHW format specially.
   if (!in.flags().row_contiguous) {
     in = contiguous_copy_gpu(in, s);
     encoder.add_temporary(in);
@@ -261,80 +377,170 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
     encoder.add_temporary(wt);
   }
 
+  return {std::move(in), std::move(wt), std::move(out)};
+}
 
+// Get the x/w/y args from the in/wt/out args depending on backend type.
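+// In cuDNN terms the forward relation is y = conv(x, w). For BACKWARD_INPUT
+// the primitive's |out| is the x being computed and |in| is the incoming
+// gradient y; for BACKWARD_WEIGHT, |out| is the w being computed and |wt|
+// holds the incoming gradient y.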
+inline std::tuple<array&, array&, array&> dispatch_args(
+    cudnnBackendDescriptorType_t backend_type,
+    array& in,
+    array& wt,
+    array& out) {
+  switch (backend_type) {
+    case CONV_BACKWARD_INPUT:
+      return {out, wt, in};
+    case CONV_BACKWARD_WEIGHT:
+      return {in, out, wt};
+    default:
+      return {in, wt, out};
+  }
+}
+
+// Register inputs and outputs before actually running the conv op. Can only
+// be called once per eval_gpu.
+void register_args(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    array& in,
+    array& wt,
+    array& intermediate_out,
+    array& final_out) {
   encoder.set_input_array(in);
   encoder.set_input_array(wt);
-  encoder.set_output_array(out);
+  encoder.set_output_array(final_out);
 
-  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
-  auto cudnn_type = dtype_to_cudnn_type(in.dtype());
+  if (backend_type == CONV_BACKWARD_WEIGHT) {
+    // Turn |out| into a strided array with C_in and C_out swapped, so the
+    // final |grad_weight| produced by the vjp ends up contiguous.
+    Strides strides = intermediate_out.strides();
+    std::swap(strides.front(), strides.back());
+    final_out.copy_shared_buffer(
+        intermediate_out,
+        std::move(strides),
+        {false, false, false},
+        intermediate_out.data_size());
+  }
+}
+
+} // namespace
+
+void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
+  nvtx3::scoped_range r("Convolution::eval_gpu");
+  if (out_.size() == 0) {
+    return;
+  }
+
+  assert(inputs.size() == 2);
+  array in = inputs[0];
+  array wt = inputs[1];
+  array out = out_;
+  out.set_data(allocator::malloc(out.nbytes()));
+  Dtype dtype = out.dtype();
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
 
   // Search cache.
   ConvCacheKey cache_key{
       encoder.device().cuda_device(),
-      backend_type,
-      cudnn_type,
+      dtype_to_cudnn_type(dtype),
       fixed_vector(in.shape()),
       fixed_vector(wt.shape()),
+      fixed_vector(kernel_strides_),
       fixed_vector(padding_lo_),
       fixed_vector(padding_hi_),
-      fixed_vector(kernel_strides_),
       fixed_vector(kernel_dilation_),
       groups_,
       flip_,
       get_alignment(in),
       get_alignment(wt),
       get_alignment(out)};
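+  // Pointer alignments are part of the key because engine selection can
+  // depend on whether vectorized loads are valid for the given buffers.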
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
-    if (!execute_plan(encoder, it->second, in, wt, out)) {
-      throw std::runtime_error("Cached convolution plan failed to execute.");
+    auto& [backend_type, plan] = it->second;
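+    // The cached plan was compiled against prepared (possibly transposed)
+    // views; this eval's arrays are new, so redo the argument preparation.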
+    std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
+    register_args(encoder, backend_type, in, wt, out, out_);
+    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+    if (!execute_plan(encoder, plan, x, w, y)) {
+      throw std::runtime_error("[conv] Cached plan failed to execute.");
    }
     return;
   }
 
-  // Build operation graph.
-  auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
-      ? CUDNN_DATA_FLOAT
-      : cudnn_type;
+  // There is no reliable way to deduce the proper cuDNN backend for the
+  // convolution, so we make a best guess and then try.
+  std::vector<cudnnBackendDescriptorType_t> try_backends;
+  if (flip_) {
+    // When the weight is flipped, we assume it is a backward input
+    // convolution.
+    try_backends.push_back(CONV_BACKWARD_INPUT);
+  } else {
+    // Otherwise it could be a backward weight convolution or a forward
+    // convolution; mathematically there is no difference, so we have to rely
+    // on heuristics. Empirically, backward convolutions have large kernel
+    // dimensions and usually have |in| and |wt| transposed.
+    if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
+        wt.shape(2) > out.shape(2)) {
+      try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
+    } else {
+      try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
+    }
+  }
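+  // A wrong guess is not fatal: build_op_graph() returns nullopt when a
+  // backend rejects the settings, and the loop below moves to the next one.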
 
-  auto stride = convert_vector<int64_t>(kernel_strides_);
-  auto padding_lo = convert_vector<int64_t>(padding_lo_);
-  auto padding_hi = convert_vector<int64_t>(padding_hi_);
-  auto dilation = convert_vector<int64_t>(kernel_dilation_);
-
+  // Try to build the op graph.
+  cudnnBackendDescriptorType_t backend_type;
+  std::optional<cudnn_frontend::OperationGraph> op_graph;
+  for (auto try_backend : try_backends) {
+    auto [in_copy, wt_copy, out_copy] =
+        prepare_args(encoder, try_backend, in, wt, out, s);
+    auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
+    auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
+        try_backend,
+        x,
+        w,
+        y,
+        kernel_strides_,
+        padding_lo_,
+        padding_hi_,
+        kernel_dilation_,
+        input_dilation_);
+    op_graph = build_op_graph(
+        encoder,
+        try_backend,
+        dtype,
+        x,
+        w,
+        y,
+        stride,
+        padding_lo,
+        padding_hi,
+        dilation);
+    if (op_graph) {
+      backend_type = try_backend;
+      in = std::move(in_copy);
+      wt = std::move(wt_copy);
+      out = std::move(out_copy);
+      break;
+    }
+  }
+  if (!op_graph) {
+    throw std::runtime_error("[conv] Can not build op graph.");
+  }
 
-  auto conv_desc = cudnn_frontend::ConvDescBuilder()
-                       .setDataType(compute_data_type)
-                       .setMathMode(CUDNN_CROSS_CORRELATION)
-                       .setNDims(stride.size())
-                       .setStrides(stride.size(), stride.data())
-                       .setPrePadding(padding_lo.size(), padding_lo.data())
-                       .setPostPadding(padding_hi.size(), padding_hi.data())
-                       .setDilation(dilation.size(), dilation.data())
-                       .build();
-
-  auto op = cudnn_frontend::OperationBuilder(backend_type)
-                .setxDesc(build_tensor('x', in))
-                .setwDesc(build_tensor('w', wt))
-                .setyDesc(build_tensor('y', out))
-                .setcDesc(conv_desc)
-                .build();
-
-  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
-  auto op_graph = cudnn_frontend::OperationGraphBuilder()
-                      .setHandle(encoder.device().cudnn_handle())
-                      .setOperationGraph(ops.size(), ops.data())
-                      .build();
+
+  // Get ready to execute the graph.
+  register_args(encoder, backend_type, in, wt, out, out_);
 
   // Try to run plans based on heuristics.
-  auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  auto op_graph_tag = op_graph.getTag();
-  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
+  auto configs = get_engine_configs(backend_type, dtype, *op_graph);
+  auto tag = op_graph->getTag();
+  auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
     return;
   }
   // Then try fallback plans.
-  configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
+  configs = get_engine_configs(backend_type, dtype, *op_graph);
+  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
     return;
   }
-  throw std::runtime_error("Unable to find an engine for convolution.");
+  throw std::runtime_error("[conv] Unable to find a working engine.");
 }
 
 } // namespace mlx::core