diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp
index ae169e35e2..4c9e39dc63 100644
--- a/mlx/backend/common/utils.cpp
+++ b/mlx/backend/common/utils.cpp
@@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
       std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
 }
 
+array swapaxes_in_eval(const array& x, int axis1, int axis2) {
+  int ndim = x.ndim();
+  if (axis1 < 0) {
+    axis1 += ndim;
+  }
+  if (axis2 < 0) {
+    axis2 += ndim;
+  }
+
+  auto shape = x.shape();
+  std::swap(shape[axis1], shape[axis2]);
+  auto strides = x.strides();
+  std::swap(strides[axis1], strides[axis2]);
+
+  auto [data_size, row_contiguous, col_contiguous] =
+      check_contiguity(shape, strides);
+  bool contiguous = data_size == x.data_size();
+
+  array out(std::move(shape), x.dtype(), nullptr, {});
+  out.copy_shared_buffer(
+      x,
+      std::move(strides),
+      {contiguous, row_contiguous, col_contiguous},
+      x.data_size());
+  return out;
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index 0f9846086a..04d804238d 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -196,6 +196,9 @@ void shared_buffer_reshape(
     const Strides& out_strides,
     array& out);
 
+// Like the swapaxes op but safe to call in eval_gpu.
+array swapaxes_in_eval(const array& x, int axis1, int axis2);
+
 template <typename T>
 inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
   vec.erase(std::next(vec.begin(), index));
diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index 578b520a04..9833c348b3 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -16,7 +16,6 @@
 #include <cudnn_frontend.h>
 
 #include <cassert>
-#include <numeric>
 
 namespace mlx::core {
 
@@ -25,25 +24,34 @@ namespace {
 
 // Not all engines support it so can not use this API now.
 #define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
 
+// Aliases for better readability.
+#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
+#define CONV_BACKWARD_INPUT \
+  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
+#define CONV_BACKWARD_WEIGHT \
+  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
+
 struct ConvCacheKey {
   int device_id;
-  cudnnBackendDescriptorType_t backend_type;
-  cudnnDataType_t cudnn_type;
+  cudnnDataType_t cudnn_dtype;
   std::array<int, MAX_NDIM> input_shape;
-  std::array<int, MAX_NDIM> filter_shape;
+  std::array<int, MAX_NDIM> weight_shape;
+  std::array<int, MAX_NDIM> stride;
   std::array<int, MAX_NDIM> padding_lo;
   std::array<int, MAX_NDIM> padding_hi;
-  std::array<int, MAX_NDIM> stride;
   std::array<int, MAX_NDIM> dilation;
   int groups;
+  bool flip;
   uint8_t input_alignment;
-  uint8_t filter_alignment;
+  uint8_t weight_alignment;
   uint8_t output_alignment;
 };
 
 auto& conv_cache() {
-  static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
-      /* capacity */ 128);
+  static LRUBytesKeyCache<
+      ConvCacheKey,
+      std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
+      cache(/* capacity */ 128);
   return cache;
 }
 
@@ -162,17 +170,17 @@ cudnn_frontend::EngineConfigList get_engine_configs(
 bool execute_plan(
     cu::CommandEncoder& encoder,
     cudnn_frontend::ExecutionPlan& plan,
-    const array& in,
-    const array& wt,
-    array& out) {
+    array& x,
+    array& w,
+    array& y) {
   int workspace_size = plan.getWorkspaceSize();
   array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
 
   int64_t uids[3] = {'x', 'w', 'y'};
   void* data_ptrs[3] = {
-      const_cast<void*>(in.data<void>()),
-      const_cast<void*>(wt.data<void>()),
-      out.data<void>(),
+      x.data<void>(),
+      w.data<void>(),
+      y.data<void>(),
   };
 
   auto variantPack = cudnn_frontend::VariantPackBuilder()
@@ -212,46 +220,154 @@ bool execute_plan(
 bool try_engines(
     cu::CommandEncoder& encoder,
-    cudnn_frontend::EngineConfigList& configs,
     const ConvCacheKey& cache_key,
+    cudnnBackendDescriptorType_t backend_type,
+    cudnn_frontend::EngineConfigList& configs,
     const std::string& op_graph_tag,
-    const array& in,
-    const array& wt,
-    array& out) {
+    array& x,
+    array& w,
+    array& y) {
   for (auto& config : configs) {
     try {
       auto plan = cudnn_frontend::ExecutionPlanBuilder()
                       .setHandle(encoder.device().cudnn_handle())
                       .setEngineConfig(config, op_graph_tag)
                       .build();
-      if (execute_plan(encoder, plan, in, wt, out)) {
-        conv_cache().emplace(cache_key, std::move(plan));
+      if (execute_plan(encoder, plan, x, w, y)) {
+        conv_cache().emplace(
+            cache_key, std::make_pair(backend_type, std::move(plan)));
         return true;
       }
-    } catch (cudnn_frontend::cudnnException&) {
+    } catch (cudnn_frontend::cudnnException& error) {
+      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
+        throw;
+      }
     }
   }
   return false;
 }
 
-} // namespace
+auto get_conv_op_settings(
+    cudnnBackendDescriptorType_t backend_type,
+    array& x,
+    array& w,
+    array& y,
+    const std::vector<int>& kernel_strides,
+    const std::vector<int>& padding_lo_,
+    const std::vector<int>& padding_hi_,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation) {
+  auto padding_lo = convert_vector<int64_t>(padding_lo_);
+  auto padding_hi = convert_vector<int64_t>(padding_hi_);
 
-void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("Convolution::eval_gpu");
-  if (out.size() == 0) {
-    return;
+  if (backend_type == CONV_BACKWARD_INPUT) {
+    for (int i = 0; i < padding_lo.size(); ++i) {
+      int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
+      padding_lo[i] = wt_size - padding_lo[i] - 1;
+      int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
+      int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
+      padding_hi[i] = out_size - in_size + padding_hi[i];
+    }
+    return std::make_tuple(
+        convert_vector<int64_t>(input_dilation),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_dilation));
+
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
+    padding_hi = padding_lo;
+    return std::make_tuple(
+        convert_vector<int64_t>(kernel_dilation),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_strides));
+
+  } else {
+    return std::make_tuple(
+        convert_vector<int64_t>(kernel_strides),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_dilation));
+  }
+}
+
+std::optional<cudnn_frontend::OperationGraph> build_op_graph(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    array& x,
+    array& w,
+    array& y,
+    const std::vector<int64_t>& stride,
+    const std::vector<int64_t>& padding_lo,
+    const std::vector<int64_t>& padding_hi,
+    const std::vector<int64_t>& dilation) {
+  try {
+    auto compute_dtype = (dtype == float16 || dtype == bfloat16)
+        ? CUDNN_DATA_FLOAT
+        : dtype_to_cudnn_type(dtype);
+    auto conv_desc = cudnn_frontend::ConvDescBuilder()
+                         .setDataType(compute_dtype)
+                         .setMathMode(CUDNN_CROSS_CORRELATION)
+                         .setNDims(stride.size())
+                         .setStrides(stride.size(), stride.data())
+                         .setPrePadding(padding_lo.size(), padding_lo.data())
+                         .setPostPadding(padding_hi.size(), padding_hi.data())
+                         .setDilation(dilation.size(), dilation.data())
+                         .build();
+
+    auto op = cudnn_frontend::OperationBuilder(backend_type)
+                  .setxDesc(build_tensor('x', x))
+                  .setwDesc(build_tensor('w', w))
+                  .setyDesc(build_tensor('y', y))
+                  .setcDesc(conv_desc)
+                  .build();
+
+    std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
+    return cudnn_frontend::OperationGraphBuilder()
+               .setHandle(encoder.device().cudnn_handle())
+               .setOperationGraph(ops.size(), ops.data())
+               .build();
+  } catch (cudnn_frontend::cudnnException& error) {
+    if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
+      throw;
+    }
+    return std::nullopt;
+  }
+}
+
+// Do the necessary transposes and copies to prepare the inputs and outputs
+// for building the cuDNN conv op. It is safe to call this multiple times in
+// one eval_gpu, at the cost of possibly redundant copies.
+std::tuple<array, array, array> prepare_args(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    array in,
+    array wt,
+    array out,
+    Stream s) {
+  // Transpose the args depending on the backend type.
+  // TODO: Handle groups.
+  if (backend_type == CONV_BACKWARD_INPUT) {
+    wt = swapaxes_in_eval(wt, 0, -1);
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
+    in = swapaxes_in_eval(in, 0, -1);
+    wt = swapaxes_in_eval(wt, 0, -1);
+    // Create a contiguous array that shares the data with |out|, but with dim
+    // C_in and C_out swapped.
+    Shape shape(out.shape());
+    std::swap(shape.front(), shape.back());
+    Strides strides(shape.size(), 1);
+    for (int i = shape.size() - 2; i >= 0; --i) {
+      strides[i] = shape[i + 1] * strides[i + 1];
+    }
+    array intermediate(std::move(shape), out.dtype(), nullptr, {});
+    intermediate.copy_shared_buffer(
+        out, std::move(strides), {true, true, false}, out.data_size());
+    out = intermediate;
   }
-  assert(inputs.size() == 2);
-  array in = inputs[0];
-  array wt = inputs[1];
-  out.set_data(allocator::malloc(out.nbytes()));
-
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
 
-  // cuDNN requires contiguous input.
-  // TODO: Handle NCHW format specially.
   if (!in.flags().row_contiguous) {
     in = contiguous_copy_gpu(in, s);
     encoder.add_temporary(in);
@@ -261,80 +377,170 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
     encoder.add_temporary(wt);
   }
 
+  return {std::move(in), std::move(wt), std::move(out)};
+}
+
+// Get the x/w/y args from the in/wt/out args depending on the backend type.
+inline std::tuple<array&, array&, array&> dispatch_args(
+    cudnnBackendDescriptorType_t backend_type,
+    array& in,
+    array& wt,
+    array& out) {
+  switch (backend_type) {
+    case CONV_BACKWARD_INPUT:
+      return {out, wt, in};
+    case CONV_BACKWARD_WEIGHT:
+      return {in, out, wt};
+    default:
+      return {in, wt, out};
+  }
+}
+
+// Register inputs and outputs before actually running the conv op. Can only
+// be called once per eval_gpu.
+void register_args(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    array& in,
+    array& wt,
+    array& intermediate_out,
+    array& final_out) {
   encoder.set_input_array(in);
   encoder.set_input_array(wt);
-  encoder.set_output_array(out);
+  encoder.set_output_array(final_out);
 
-  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
-  auto cudnn_type = dtype_to_cudnn_type(in.dtype());
+  if (backend_type == CONV_BACKWARD_WEIGHT) {
+    // Turn |out| into a strided array, which will have C_in and C_out swapped
+    // in vjp and the final |grad_weight| will then be contiguous.
+    Strides strides = intermediate_out.strides();
+    std::swap(strides.front(), strides.back());
+    final_out.copy_shared_buffer(
+        intermediate_out,
+        std::move(strides),
+        {false, false, false},
+        intermediate_out.data_size());
+  }
+}
+
+} // namespace
+
+void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
+  nvtx3::scoped_range r("Convolution::eval_gpu");
+  if (out_.size() == 0) {
+    return;
+  }
+
+  assert(inputs.size() == 2);
+  array in = inputs[0];
+  array wt = inputs[1];
+  array out = out_;
+  out.set_data(allocator::malloc(out.nbytes()));
+  Dtype dtype = out.dtype();
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
 
   // Search cache.
   ConvCacheKey cache_key{
       encoder.device().cuda_device(),
-      backend_type,
-      cudnn_type,
+      dtype_to_cudnn_type(dtype),
       fixed_vector(in.shape()),
       fixed_vector(wt.shape()),
+      fixed_vector(kernel_strides_),
       fixed_vector(padding_lo_),
       fixed_vector(padding_hi_),
-      fixed_vector(kernel_strides_),
       fixed_vector(kernel_dilation_),
       groups_,
+      flip_,
       get_alignment(in),
       get_alignment(wt),
       get_alignment(out)};
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
-    if (!execute_plan(encoder, it->second, in, wt, out)) {
-      throw std::runtime_error("Cached convolution plan failed to execute.");
+    auto& [backend_type, plan] = it->second;
+    std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
+    register_args(encoder, backend_type, in, wt, out, out_);
+    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+    if (!execute_plan(encoder, plan, x, w, y)) {
+      throw std::runtime_error("[conv] Cached plan failed to execute.");
     }
     return;
   }
 
-  // Build operation graph.
-  auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
-      ? CUDNN_DATA_FLOAT
-      : cudnn_type;
+  // There is no reliable way to deduce the proper cuDNN backend for the
+  // convolution, so we make a best guess and then try.
+  std::vector<cudnnBackendDescriptorType_t> try_backends;
+  if (flip_) {
+    // When the weight is flipped, we assume it is a backward input
+    // convolution.
+    try_backends.push_back(CONV_BACKWARD_INPUT);
+  } else {
+    // Otherwise it could be a backward weight convolution or a forward
+    // convolution; mathematically there is no difference, so we have to use
+    // heuristics. Empirically, backward weight convolutions have large
+    // kernel dimensions, and usually have |in| and |wt| transposed.
+    if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
+        wt.shape(2) > out.shape(2)) {
+      try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
+    } else {
+      try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
+    }
+  }
 
-  auto stride = convert_vector<int64_t>(kernel_strides_);
-  auto padding_lo = convert_vector<int64_t>(padding_lo_);
-  auto padding_hi = convert_vector<int64_t>(padding_hi_);
-  auto dilation = convert_vector<int64_t>(kernel_dilation_);
+  // Try to build op graph.
+  cudnnBackendDescriptorType_t backend_type;
+  std::optional<cudnn_frontend::OperationGraph> op_graph;
+  for (auto try_backend : try_backends) {
+    auto [in_copy, wt_copy, out_copy] =
+        prepare_args(encoder, try_backend, in, wt, out, s);
+    auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
+    auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
+        try_backend,
+        x,
+        w,
+        y,
+        kernel_strides_,
+        padding_lo_,
+        padding_hi_,
+        kernel_dilation_,
+        input_dilation_);
+    op_graph = build_op_graph(
+        encoder,
+        try_backend,
+        dtype,
+        x,
+        w,
+        y,
+        stride,
+        padding_lo,
+        padding_hi,
+        dilation);
+    if (op_graph) {
+      backend_type = try_backend;
+      in = std::move(in_copy);
+      wt = std::move(wt_copy);
+      out = std::move(out_copy);
+      break;
+    }
+  }
+  if (!op_graph) {
+    throw std::runtime_error("[conv] Can not build op graph.");
+  }
 
-  auto conv_desc = cudnn_frontend::ConvDescBuilder()
-                       .setDataType(compute_data_type)
-                       .setMathMode(CUDNN_CROSS_CORRELATION)
-                       .setNDims(stride.size())
-                       .setStrides(stride.size(), stride.data())
-                       .setPrePadding(padding_lo.size(), padding_lo.data())
-                       .setPostPadding(padding_hi.size(), padding_hi.data())
-                       .setDilation(dilation.size(), dilation.data())
-                       .build();
-
-  auto op = cudnn_frontend::OperationBuilder(backend_type)
-                .setxDesc(build_tensor('x', in))
-                .setwDesc(build_tensor('w', wt))
-                .setyDesc(build_tensor('y', out))
-                .setcDesc(conv_desc)
-                .build();
-
-  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
-  auto op_graph = cudnn_frontend::OperationGraphBuilder()
-                      .setHandle(encoder.device().cudnn_handle())
-                      .setOperationGraph(ops.size(), ops.data())
-                      .build();
+  // Get ready to execute the graph.
+  register_args(encoder, backend_type, in, wt, out, out_);
 
   // Try to run plans based on heuristics.
-  auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  auto op_graph_tag = op_graph.getTag();
-  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
+  auto configs = get_engine_configs(backend_type, dtype, *op_graph);
+  auto tag = op_graph->getTag();
+  auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
     return;
   }
   // Then try fallback plans.
-  configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
+  configs = get_engine_configs(backend_type, dtype, *op_graph);
+  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
     return;
   }
-  throw std::runtime_error("Unable to find an engine for convolution.");
+  throw std::runtime_error("[conv] Unable to find a working engine.");
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
index e8e0f389c6..5bbd72fd53 100644
--- a/mlx/backend/cuda/sort.cu
+++ b/mlx/backend/cuda/sort.cu
@@ -13,7 +13,6 @@
 #include <nvtx3/nvtx3.hpp>
 
 #include <cassert>
-#include <numeric>
 
 namespace mlx::core {
 
@@ -27,29 +26,6 @@ struct ModOp {
   }
 };
 
-// We can not use any op in eval, make an utility.
-array swapaxes_in_eval(const array& in, int axis1, int axis2) {
-  std::vector<int> axes(in.ndim());
-  std::iota(axes.begin(), axes.end(), 0);
-  std::swap(axes[axis1], axes[axis2]);
-  // TODO: Share the code with Transpose::eval.
-  Shape shape(axes.size());
-  Strides strides(in.ndim());
-  for (size_t ax = 0; ax < axes.size(); ++ax) {
-    shape[ax] = in.shape()[axes[ax]];
-    strides[ax] = in.strides()[axes[ax]];
-  }
-  auto flags = in.flags();
-  if (flags.contiguous) {
-    auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
-    flags.row_contiguous = row_contiguous;
-    flags.col_contiguous = col_contiguous;
-  }
-  array out(shape, in.dtype(), nullptr, {});
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
-  return out;
-}
-
 struct OffsetTransform {
   int nsort;
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index e14aa675e0..0f57100e82 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -18,21 +18,12 @@ cuda_skip = {
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
     "TestConv.test_conv_groups_grad",
-    "TestConv.test_torch_conv_1D_grad",
     "TestConv.test_torch_conv_2D",
-    "TestConv.test_torch_conv_2D_grad",
-    "TestConv.test_torch_conv_3D_grad",
     "TestConv.test_torch_conv_depthwise",
     "TestConv.test_torch_conv_general",
-    "TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
-    "TestConvTranspose.test_torch_conv_transpose_1D",
     "TestConvTranspose.test_torch_conv_transpose_1D_grad",
-    "TestConvTranspose.test_torch_conv_transpose_2D",
     "TestConvTranspose.test_torch_conv_transpose_2D_grad",
-    "TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
-    "TestConvTranspose.test_torch_conv_transpose_3D",
     "TestConvTranspose.test_torch_conv_transpose_3D_grad",
-    "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
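
Note, not part of the patch: swapaxes_in_eval avoids running a real op during eval by returning a strided view, i.e. it swaps one entry of the shape and the matching entry of the strides over the same buffer. The following standalone C++ sketch (plain arrays instead of mlx::core::array, all names ours) shows why that indexing trick is a zero-copy transpose; it also shows why the patch must recompute contiguity flags, since such a view is generally neither row- nor column-contiguous.

// Standalone illustration of the stride-swap view used by swapaxes_in_eval.
#include <cstdio>

int main() {
  // A 2x3 row-major matrix stored in a flat buffer: [[0, 1, 2], [3, 4, 5]].
  float data[6] = {0, 1, 2, 3, 4, 5};
  int shape[2] = {2, 3};
  int strides[2] = {3, 1};

  // "Swap axes 0 and 1": swap the shape and stride entries. The buffer is
  // untouched; only the index arithmetic changes.
  int t_shape[2] = {shape[1], shape[0]};       // 3x2
  int t_strides[2] = {strides[1], strides[0]}; // {1, 3}

  for (int i = 0; i < t_shape[0]; ++i) {
    for (int j = 0; j < t_shape[1]; ++j) {
      std::printf("%g ", data[i * t_strides[0] + j * t_strides[1]]);
    }
    std::printf("\n");
  }
  // Prints the transpose without copying:
  // 0 3
  // 1 4
  // 2 5
  return 0;
}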
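Note, not part of the patch: a 1-D sanity check of the CONV_BACKWARD_INPUT padding algebra in get_conv_op_settings. It assumes MLX's vjp expresses grad-input as a flipped convolution with input_dilation equal to the forward stride and padding equal to dilated_kernel - 1 - forward_padding; the helper names and the numbers are only an illustration, not MLX code.

// Checks that the rewritten padding satisfies cuDNN's backward-data size
// identity: dx = (dy - 1) * stride - pad_lo - pad_hi + dilated_kernel.
#include <cassert>
#include <cstdio>

int main() {
  // Forward conv: input 5, kernel 3, stride 2, padding 1, no dilation.
  int x_size = 5, w_size = 3, stride = 2, fwd_pad = 1, dilation = 1;
  int y_size = (x_size + 2 * fwd_pad - w_size) / stride + 1; // = 3

  // Arguments as get_conv_op_settings would see them for CONV_BACKWARD_INPUT:
  // x = grad_input (size 5), y = cotangent (size 3).
  int kernel_strides = 1, input_dilation = stride;
  int pad_lo = w_size - 1 - fwd_pad; // assumed vjp padding, = 1
  int pad_hi = pad_lo;

  // The arithmetic from the CONV_BACKWARD_INPUT branch above:
  int wt_size = 1 + dilation * (w_size - 1);          // = 3
  int new_pad_lo = wt_size - pad_lo - 1;              // = 1
  int in_size = 1 + kernel_strides * (x_size - 1);    // = 5
  int out_size = 1 + input_dilation * (y_size - 1);   // = 5
  int new_pad_hi = out_size - in_size + pad_hi;       // = 1

  int dx = (y_size - 1) * input_dilation - new_pad_lo - new_pad_hi + wt_size;
  assert(dx == x_size); // grad_input size is recovered exactly
  std::printf("pad_lo=%d pad_hi=%d dx=%d\n", new_pad_lo, new_pad_hi, dx);
  return 0;
}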