Compare commits


3 Commits

Author                  SHA1         Message                                        Date
Angelos Katharopoulos   86258f292f   [CUDA] Vectorize generated kernels (#2444)     2025-07-31 18:18:57 -07:00
Cheng                   b26d88591c   [CUDA] Save primitive inputs faster (#2449)    2025-08-01 10:16:06 +09:00
                                       * Add more nvtx loggings
                                       * [CUDA] Saving primitive inputs faster
                                       * Remove unneeded check
Cheng                   86c6a15571   [CUDA] Backward convolution (#2431)            2025-08-01 09:54:05 +09:00
11 changed files with 392 additions and 148 deletions

View File

@@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
       std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
 }
 
+array swapaxes_in_eval(const array& x, int axis1, int axis2) {
+  int ndim = x.ndim();
+  if (axis1 < 0) {
+    axis1 += ndim;
+  }
+  if (axis2 < 0) {
+    axis2 += ndim;
+  }
+  auto shape = x.shape();
+  std::swap(shape[axis1], shape[axis2]);
+  auto strides = x.strides();
+  std::swap(strides[axis1], strides[axis2]);
+  auto [data_size, row_contiguous, col_contiguous] =
+      check_contiguity(shape, strides);
+  bool contiguous = data_size == x.data_size();
+  array out(std::move(shape), x.dtype(), nullptr, {});
+  out.copy_shared_buffer(
+      x,
+      std::move(strides),
+      {contiguous, row_contiguous, col_contiguous},
+      x.data_size());
+  return out;
+}
+
 } // namespace mlx::core
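
The helper above builds a transposed view by swapping the corresponding shape and stride entries and re-deriving the contiguity flags, so the result shares the input's buffer and no kernel needs to run. A rough usage sketch (the shapes are illustrative, not taken from this diff):

    // Hypothetical weight of shape (C_out, kH, kW, C_in) = (64, 3, 3, 16).
    // The backward-convolution paths below view it with C_out and C_in swapped:
    array wt_t = swapaxes_in_eval(wt, 0, -1); // shape (16, 3, 3, 64), no copy
    // wt_t shares wt's buffer; only shape, strides and contiguity flags differ,
    // which is what makes it safe to call inside eval_gpu.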

View File

@@ -196,6 +196,9 @@ void shared_buffer_reshape(
     const Strides& out_strides,
     array& out);
 
+// Like the swapaxes op but safe to call in eval_gpu.
+array swapaxes_in_eval(const array& x, int axis1, int axis2);
+
 template <typename T>
 inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
   vec.erase(std::next(vec.begin(), index));

View File

@@ -104,10 +104,41 @@ struct FusedKernelBuilder {
           " }\n";
     }
 
+    // Vectorized read loop
+    if (contiguous) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        if (is_scalar(x) || is_constant(i)) {
+          continue;
+        }
+        const std::string& xname = namer.get_name(x);
+        std::string type = dtype_to_cuda_type(x.dtype());
+        os += fmt::format(
+            " auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
+            xname,
+            type);
+      }
+    }
+
+    // Create some space for the outputs
+    for (const auto& x : outputs) {
+      const std::string& xname = namer.get_name(x);
+      std::string type = dtype_to_cuda_type(x.dtype());
+      os += fmt::format(
+          " AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
+    }
+
     // Work loop
-    os +=
-        "\n"
-        " for (int i = 0; i < work_per_thread && index < size; i++) {\n";
+    if (!contiguous) {
+      os +=
+          "\n"
+          " for (int i = 0; i < work_per_thread && index < size; i++) {\n";
+    } else {
+      os +=
+          "\n"
+          " #pragma unroll\n"
+          " for (int i = 0; i < work_per_thread; i++) {\n";
+    }
 
     // Read inputs.
     for (size_t i = 0; i < inputs.size(); ++i) {
@@ -122,7 +153,7 @@ struct FusedKernelBuilder {
       } else if (is_scalar(x)) {
         value = fmt::format("{}[0]", xname);
       } else if (contiguous) {
-        value = fmt::format("{}[index]", xname);
+        value = fmt::format("vec_{}[i]", xname);
       } else {
         value = fmt::format("{}[{}_idx]", xname, xname);
       }
@@ -150,25 +181,30 @@ struct FusedKernelBuilder {
     // Write output.
     for (const auto& x : outputs) {
-      os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
+      os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
     }
 
     // End of work loop
-    os +=
-        "\n"
-        " index++;\n";
     if (!contiguous) {
+      os += "\n";
       for (size_t i = 0; i < inputs.size(); ++i) {
         const auto& x = inputs[i];
         const std::string& xname = namer.get_name(x);
         if (is_scalar(x) || is_constant(i)) {
           continue;
         }
-        os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
+        os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
       }
     }
     os += " }\n";
 
+    // Store the output to global memory
+    for (const auto& x : outputs) {
+      os += fmt::format(
+          " store_vector({0} + index, 0, vec_{0}, size - index);\n",
+          namer.get_name(x));
+    }
+
     os += "}\n";
   }
 };
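
Taken together, the contiguous path now reads every input in work_per_thread-wide chunks, keeps the per-iteration results in a register-resident AlignedVector, and writes each output back with a single store_vector call. A rough sketch of the body this builder might emit for one float input a and one output out (illustrative only; the index/size setup and the helper definitions are assumed from surrounding code that is not part of this diff):

    // Hypothetical emitted body for a fused unary kernel (not verbatim output):
    auto vec_a = load_vector<work_per_thread, float>(a + index, 0, size - index, 0);
    AlignedVector<float, work_per_thread> vec_out;

    #pragma unroll
    for (int i = 0; i < work_per_thread; i++) {
      float tmp_a = vec_a[i];
      float tmp_out = fused_expression(tmp_a); // placeholder for the fused ops
      vec_out[i] = tmp_out;
    }
    store_vector(out + index, 0, vec_out, size - index);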
@@ -192,6 +228,15 @@ void Compiled::eval_gpu(
   nvtx3::scoped_range r("Compiled::eval_gpu");
   auto& s = stream();
 
+  // Determine the work per thread for the vectorized reads/writes. We take it
+  // as 16 over the max itemsize for the outputs. Another heuristic could be
+  // over the max itemsize of all arrays.
+  int max_size = 1;
+  for (const auto& x : outputs) {
+    max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
+  }
+  int work_per_thread = 16 / max_size;
+
   cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
     // Build source code.
     cu::FusedKernelBuilder builder{
@@ -205,28 +250,23 @@ void Compiled::eval_gpu(
   builder.os += "\n} // namespace mlx::core::cu\n";
   // Build kernel names.
   std::vector<std::string> kernel_names;
-  for (auto work_per_thread : std::array<int, 2>{1, 4}) {
-    kernel_names.push_back(fmt::format(
-        "mlx::core::cu::{}_contiguous<uint32_t, {}>",
-        lib_name(),
-        work_per_thread));
-    kernel_names.push_back(fmt::format(
-        "mlx::core::cu::{}_contiguous<int64_t, {}>",
-        lib_name(),
-        work_per_thread));
+  kernel_names.push_back(fmt::format(
+      "mlx::core::cu::{}_contiguous<uint32_t, {}>",
+      lib_name(),
+      work_per_thread));
+  kernel_names.push_back(fmt::format(
+      "mlx::core::cu::{}_contiguous<int64_t, {}>",
+      lib_name(),
+      work_per_thread));
+  for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
     for (int i = 1; i <= MAX_NDIM; ++i) {
       kernel_names.push_back(fmt::format(
-          "mlx::core::cu::{}_strided<{}, uint32_t, {}>",
-          lib_name(),
-          i,
-          work_per_thread));
+          "mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
       kernel_names.push_back(fmt::format(
-          "mlx::core::cu::{}_strided<{}, int64_t, {}>",
-          lib_name(),
-          i,
-          work_per_thread));
+          "mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
     }
   }
   return std::make_pair(std::move(builder.os), std::move(kernel_names));
 });
@@ -269,7 +309,6 @@ void Compiled::eval_gpu(
   }
 
   // Choose work per thread
-  int work_per_thread = 4;
   if (!contiguous && shape.back() % work_per_thread != 0) {
     work_per_thread = 1;
   }
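
Worked through, the 16 / max_size heuristic above combined with this fallback gives:

    // Worked examples of the work_per_thread choice:
    //   float16 / bfloat16 outputs (2 bytes) -> 16 / 2 = 8 elements per thread
    //   float32 / int32 outputs    (4 bytes) -> 16 / 4 = 4 elements per thread
    //   complex64 outputs          (8 bytes) -> 16 / 8 = 2 elements per thread
    // For the strided (non-contiguous) kernels it additionally drops to 1 when
    // the innermost dimension is not divisible by work_per_thread.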

View File

@@ -16,7 +16,6 @@
 #include <nvtx3/nvtx3.hpp>
 
 #include <cassert>
-#include <numeric>
 
 namespace mlx::core {
@@ -25,25 +24,34 @@ namespace {
 // Not all engines support it so can not use this API now.
 #define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
 
+// Alias for better readability.
+#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
+#define CONV_BACKWARD_INPUT \
+  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
+#define CONV_BACKWARD_WEIGHT \
+  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
+
 struct ConvCacheKey {
   int device_id;
-  cudnnBackendDescriptorType_t backend_type;
-  cudnnDataType_t cudnn_type;
+  cudnnDataType_t cudnn_dtype;
   std::array<int, MAX_NDIM> input_shape;
-  std::array<int, MAX_NDIM> filter_shape;
+  std::array<int, MAX_NDIM> weight_shape;
+  std::array<int, MAX_NDIM> stride;
   std::array<int, MAX_NDIM> padding_lo;
   std::array<int, MAX_NDIM> padding_hi;
-  std::array<int, MAX_NDIM> stride;
   std::array<int, MAX_NDIM> dilation;
   int groups;
+  bool flip;
   uint8_t input_alignment;
-  uint8_t filter_alignment;
+  uint8_t weight_alignment;
   uint8_t output_alignment;
 };
 
 auto& conv_cache() {
-  static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
-      /* capacity */ 128);
+  static LRUBytesKeyCache<
+      ConvCacheKey,
+      std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
+      cache(/* capacity */ 128);
   return cache;
 }
@@ -162,17 +170,17 @@ cudnn_frontend::EngineConfigList get_engine_configs(
 bool execute_plan(
     cu::CommandEncoder& encoder,
     cudnn_frontend::ExecutionPlan& plan,
-    const array& in,
-    const array& wt,
-    array& out) {
+    array& x,
+    array& w,
+    array& y) {
   int workspace_size = plan.getWorkspaceSize();
   array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
 
   int64_t uids[3] = {'x', 'w', 'y'};
   void* data_ptrs[3] = {
-      const_cast<void*>(in.data<void>()),
-      const_cast<void*>(wt.data<void>()),
-      out.data<void>(),
+      x.data<void>(),
+      w.data<void>(),
+      y.data<void>(),
   };
 
   auto variantPack = cudnn_frontend::VariantPackBuilder()
@@ -212,46 +220,154 @@ bool execute_plan(
 bool try_engines(
     cu::CommandEncoder& encoder,
-    cudnn_frontend::EngineConfigList& configs,
     const ConvCacheKey& cache_key,
+    cudnnBackendDescriptorType_t backend_type,
+    cudnn_frontend::EngineConfigList& configs,
     const std::string& op_graph_tag,
-    const array& in,
-    const array& wt,
-    array& out) {
+    array& x,
+    array& w,
+    array& y) {
   for (auto& config : configs) {
     try {
       auto plan = cudnn_frontend::ExecutionPlanBuilder()
                       .setHandle(encoder.device().cudnn_handle())
                       .setEngineConfig(config, op_graph_tag)
                       .build();
-      if (execute_plan(encoder, plan, in, wt, out)) {
-        conv_cache().emplace(cache_key, std::move(plan));
+      if (execute_plan(encoder, plan, x, w, y)) {
+        conv_cache().emplace(
+            cache_key, std::make_pair(backend_type, std::move(plan)));
         return true;
       }
-    } catch (cudnn_frontend::cudnnException&) {
+    } catch (cudnn_frontend::cudnnException& error) {
+      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
+        throw;
+      }
     }
   }
   return false;
 }
 
-} // namespace
-
-void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("Convolution::eval_gpu");
-  if (out.size() == 0) {
-    return;
-  }
-
-  assert(inputs.size() == 2);
-  array in = inputs[0];
-  array wt = inputs[1];
-  out.set_data(allocator::malloc(out.nbytes()));
-
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
-
+auto get_conv_op_settings(
+    cudnnBackendDescriptorType_t backend_type,
+    array& x,
+    array& w,
+    array& y,
+    const std::vector<int>& kernel_strides,
+    const std::vector<int>& padding_lo_,
+    const std::vector<int>& padding_hi_,
+    const std::vector<int>& kernel_dilation,
+    const std::vector<int>& input_dilation) {
+  auto padding_lo = convert_vector<int64_t>(padding_lo_);
+  auto padding_hi = convert_vector<int64_t>(padding_hi_);
+
+  if (backend_type == CONV_BACKWARD_INPUT) {
+    for (int i = 0; i < padding_lo.size(); ++i) {
+      int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
+      padding_lo[i] = wt_size - padding_lo[i] - 1;
+      int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
+      int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
+      padding_hi[i] = out_size - in_size + padding_hi[i];
+    }
+    return std::make_tuple(
+        convert_vector<int64_t>(input_dilation),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_dilation));
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
+    padding_hi = padding_lo;
+    return std::make_tuple(
+        convert_vector<int64_t>(kernel_dilation),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_strides));
+  } else {
+    return std::make_tuple(
+        convert_vector<int64_t>(kernel_strides),
+        std::move(padding_lo),
+        std::move(padding_hi),
+        convert_vector<int64_t>(kernel_dilation));
+  }
+}
+
+std::optional<cudnn_frontend::OperationGraph> build_op_graph(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    array& x,
+    array& w,
+    array& y,
+    const std::vector<int64_t>& stride,
+    const std::vector<int64_t>& padding_lo,
+    const std::vector<int64_t>& padding_hi,
+    const std::vector<int64_t>& dilation) {
+  try {
+    auto compute_dtype = (dtype == float16 || dtype == bfloat16)
+        ? CUDNN_DATA_FLOAT
+        : dtype_to_cudnn_type(dtype);
+    auto conv_desc = cudnn_frontend::ConvDescBuilder()
+                         .setDataType(compute_dtype)
+                         .setMathMode(CUDNN_CROSS_CORRELATION)
+                         .setNDims(stride.size())
+                         .setStrides(stride.size(), stride.data())
+                         .setPrePadding(padding_lo.size(), padding_lo.data())
+                         .setPostPadding(padding_hi.size(), padding_hi.data())
+                         .setDilation(dilation.size(), dilation.data())
+                         .build();
+    auto op = cudnn_frontend::OperationBuilder(backend_type)
+                  .setxDesc(build_tensor('x', x))
+                  .setwDesc(build_tensor('w', w))
+                  .setyDesc(build_tensor('y', y))
+                  .setcDesc(conv_desc)
+                  .build();
+    std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
+    return cudnn_frontend::OperationGraphBuilder()
+        .setHandle(encoder.device().cudnn_handle())
+        .setOperationGraph(ops.size(), ops.data())
+        .build();
+  } catch (cudnn_frontend::cudnnException& error) {
+    if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
+      throw;
+    }
+    return std::nullopt;
+  }
+}
+
+// Do necessary transposes and copies to prepare the inputs and outputs for
+// building the cuDNN conv op. It is safe to be called multiple times in one
+// eval_gpu, with cost of possible redundant copies.
+std::tuple<array, array, array> prepare_args(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    array in,
+    array wt,
+    array out,
+    Stream s) {
+  // Transpose the args depending on the backend type.
+  // TODO: Handle groups.
+  if (backend_type == CONV_BACKWARD_INPUT) {
+    wt = swapaxes_in_eval(wt, 0, -1);
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
+    in = swapaxes_in_eval(in, 0, -1);
+    wt = swapaxes_in_eval(wt, 0, -1);
+    // Create a contiguous array that shares the data with |out|, but with dim
+    // C_in and C_out swapped.
+    Shape shape(out.shape());
+    std::swap(shape.front(), shape.back());
+    Strides strides(shape.size(), 1);
+    for (int i = shape.size() - 2; i >= 0; --i) {
+      strides[i] = shape[i + 1] * strides[i + 1];
+    }
+    array intermediate(std::move(shape), out.dtype(), nullptr, {});
+    intermediate.copy_shared_buffer(
+        out, std::move(strides), {true, true, false}, out.data_size());
+    out = intermediate;
+  }
+
   // cuDNN requires contiguous input.
+  // TODO: Handle NCHW format specially.
   if (!in.flags().row_contiguous) {
     in = contiguous_copy_gpu(in, s);
     encoder.add_temporary(in);
@@ -261,80 +377,170 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
     encoder.add_temporary(wt);
   }
 
+  return {std::move(in), std::move(wt), std::move(out)};
+}
+
+// Get the x/w/y args from the in/wt/out args depending on backend type.
+inline std::tuple<array&, array&, array&> dispatch_args(
+    cudnnBackendDescriptorType_t backend_type,
+    array& in,
+    array& wt,
+    array& out) {
+  switch (backend_type) {
+    case CONV_BACKWARD_INPUT:
+      return {out, wt, in};
+    case CONV_BACKWARD_WEIGHT:
+      return {in, out, wt};
+    default:
+      return {in, wt, out};
+  }
+}
+
+// Register inputs and outputs before actually running conv op. Can only be
+// called once per eval_gpu.
+void register_args(
+    cu::CommandEncoder& encoder,
+    cudnnBackendDescriptorType_t backend_type,
+    array& in,
+    array& wt,
+    array& intermediate_out,
+    array& final_out) {
   encoder.set_input_array(in);
   encoder.set_input_array(wt);
-  encoder.set_output_array(out);
-
-  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
-  auto cudnn_type = dtype_to_cudnn_type(in.dtype());
+  encoder.set_output_array(final_out);
+
+  if (backend_type == CONV_BACKWARD_WEIGHT) {
+    // Turn |out| into a strided array, which will have C_in and C_out swapped
+    // in vjp and the final |grad_weight| will then be contiguous.
+    Strides strides = intermediate_out.strides();
+    std::swap(strides.front(), strides.back());
+    final_out.copy_shared_buffer(
+        intermediate_out,
+        std::move(strides),
+        {false, false, false},
+        intermediate_out.data_size());
+  }
+}
+
+} // namespace
+
+void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
+  nvtx3::scoped_range r("Convolution::eval_gpu");
+  if (out_.size() == 0) {
+    return;
+  }
+
+  assert(inputs.size() == 2);
+  array in = inputs[0];
+  array wt = inputs[1];
+  array out = out_;
+  out.set_data(allocator::malloc(out.nbytes()));
+  Dtype dtype = out.dtype();
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
 
   // Search cache.
   ConvCacheKey cache_key{
       encoder.device().cuda_device(),
-      backend_type,
-      cudnn_type,
+      dtype_to_cudnn_type(dtype),
       fixed_vector(in.shape()),
       fixed_vector(wt.shape()),
+      fixed_vector(kernel_strides_),
       fixed_vector(padding_lo_),
       fixed_vector(padding_hi_),
-      fixed_vector(kernel_strides_),
       fixed_vector(kernel_dilation_),
       groups_,
+      flip_,
       get_alignment(in),
       get_alignment(wt),
       get_alignment(out)};
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
-    if (!execute_plan(encoder, it->second, in, wt, out)) {
-      throw std::runtime_error("Cached convolution plan failed to execute.");
+    auto& [backend_type, plan] = it->second;
+    std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
+    register_args(encoder, backend_type, in, wt, out, out_);
+    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+    if (!execute_plan(encoder, plan, x, w, y)) {
+      throw std::runtime_error("[conv] Cached plan failed to execute.");
     }
     return;
   }
 
-  // Build operation graph.
-  auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
-      ? CUDNN_DATA_FLOAT
-      : cudnn_type;
-
-  auto stride = convert_vector<int64_t>(kernel_strides_);
-  auto padding_lo = convert_vector<int64_t>(padding_lo_);
-  auto padding_hi = convert_vector<int64_t>(padding_hi_);
-  auto dilation = convert_vector<int64_t>(kernel_dilation_);
-
-  auto conv_desc = cudnn_frontend::ConvDescBuilder()
-                       .setDataType(compute_data_type)
-                       .setMathMode(CUDNN_CROSS_CORRELATION)
-                       .setNDims(stride.size())
-                       .setStrides(stride.size(), stride.data())
-                       .setPrePadding(padding_lo.size(), padding_lo.data())
-                       .setPostPadding(padding_hi.size(), padding_hi.data())
-                       .setDilation(dilation.size(), dilation.data())
-                       .build();
-
-  auto op = cudnn_frontend::OperationBuilder(backend_type)
-                .setxDesc(build_tensor('x', in))
-                .setwDesc(build_tensor('w', wt))
-                .setyDesc(build_tensor('y', out))
-                .setcDesc(conv_desc)
-                .build();
-
-  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
-  auto op_graph = cudnn_frontend::OperationGraphBuilder()
-                      .setHandle(encoder.device().cudnn_handle())
-                      .setOperationGraph(ops.size(), ops.data())
-                      .build();
+  // There is no reliable way to deduce the proper cuDNN backend for the
+  // convolution, so we make a best guess and then try.
+  std::vector<cudnnBackendDescriptorType_t> try_backends;
+  if (flip_) {
+    // When weight is flipped, we assume it is backward input convolution.
+    try_backends.push_back(CONV_BACKWARD_INPUT);
+  } else {
+    // Otherwise it could be backward weight convolution or forward convolution,
+    // mathematically there is no difference so we have to use heuristics.
+    // Empirically backward convolutions have large kernel dimensions, and
+    // usually have |in| and |wt| transposed.
+    if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
+        wt.shape(2) > out.shape(2)) {
+      try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
+    } else {
+      try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
+    }
+  }
+
+  // Try to build op graph.
+  cudnnBackendDescriptorType_t backend_type;
+  std::optional<cudnn_frontend::OperationGraph> op_graph;
+  for (auto try_backend : try_backends) {
+    auto [in_copy, wt_copy, out_copy] =
+        prepare_args(encoder, try_backend, in, wt, out, s);
+    auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
+    auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
+        try_backend,
+        x,
+        w,
+        y,
+        kernel_strides_,
+        padding_lo_,
+        padding_hi_,
+        kernel_dilation_,
+        input_dilation_);
+    op_graph = build_op_graph(
+        encoder,
+        try_backend,
+        dtype,
+        x,
+        w,
+        y,
+        stride,
+        padding_lo,
+        padding_hi,
+        dilation);
+    if (op_graph) {
+      backend_type = try_backend;
+      in = std::move(in_copy);
+      wt = std::move(wt_copy);
+      out = std::move(out_copy);
+      break;
+    }
+  }
+  if (!op_graph) {
+    throw std::runtime_error("[conv] Can not build op graph.");
+  }
+
+  // Get ready to execute the graph.
+  register_args(encoder, backend_type, in, wt, out, out_);
 
   // Try to run plans based on heuristics.
-  auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  auto op_graph_tag = op_graph.getTag();
-  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
+  auto configs = get_engine_configs(backend_type, dtype, *op_graph);
+  auto tag = op_graph->getTag();
+  auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
     return;
   }
   // Then try fallback plans.
-  configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
+  configs = get_engine_configs(backend_type, dtype, *op_graph);
+  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
     return;
   }
-  throw std::runtime_error("Unable to find an engine for convolution.");
+  throw std::runtime_error("[conv] Unable to find a working engine.");
 }
 
 } // namespace mlx::core
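
One detail worth spelling out is the padding rewrite in the CONV_BACKWARD_INPUT branch of get_conv_op_settings above. A worked 1D example, assuming the primitive came from the input-gradient vjp of a stride-1, unpadded convolution with kernel size 3, forward input length 32 and forward output length 30, so that the flipped-weight convolution carries padding 2 on both sides (these numbers are illustrative assumptions, not taken from the diff):

    // Plugging the assumed sizes into the formulas above:
    //   wt_size    = 1 + kernel_dilation * (w.shape(1) - 1) = 1 + 1 * (3 - 1)  = 3
    //   padding_lo = wt_size - padding_lo - 1               = 3 - 2 - 1        = 0
    //   in_size    = 1 + kernel_strides * (x.shape(1) - 1)  = 1 + 1 * (32 - 1) = 32
    //   out_size   = 1 + input_dilation * (y.shape(1) - 1)  = 1 + 1 * (30 - 1) = 30
    //   padding_hi = out_size - in_size + padding_hi        = 30 - 32 + 2      = 0
    // i.e. the padding of the flipped-weight convolution is mapped back to the
    // padding of the original forward convolution, (0, 0), which matches the
    // forward geometry that the cuDNN backward-data operation is described with.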

View File

@@ -269,6 +269,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
 }
 
 void CommandEncoder::commit() {
+  nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }

View File

@@ -36,18 +36,15 @@ void eval(array& arr) {
   auto& encoder = cu::get_command_encoder(arr.primitive().stream());
 
   // Keep used buffers alive until kernel finishes running.
-  std::unordered_set<std::shared_ptr<array::Data>> buffers;
   for (auto& in : arr.inputs()) {
-    buffers.insert(in.data_shared_ptr());
+    // Except for the donated one.
+    if (in.data_shared_ptr() != arr.data_shared_ptr()) {
+      encoder.add_temporary(in);
+    }
   }
   for (auto& s : arr.siblings()) {
-    buffers.insert(s.data_shared_ptr());
+    encoder.add_temporary(s);
   }
-  // Remove the output if it was donated to by an input.
-  if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
-    buffers.erase(it);
-  }
-  encoder.add_completed_handler([buffers = std::move(buffers)]() {});
   encoder.maybe_commit();
 }

View File

@@ -5,6 +5,8 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/fast_primitives.h"
 
+#include <nvtx3/nvtx3.hpp>
+
 namespace mlx::core {
 
 namespace {
@@ -42,6 +44,7 @@ inline array ensure_row_contiguous_matrix(
 void fast::AffineQuantize::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
+  nvtx3::scoped_range r("AffineQuantize::eval_gpu");
   auto& s = stream();
   auto& d = cu::device(s.device);
   auto& enc = d.get_command_encoder(s);

View File

@@ -13,7 +13,6 @@
 #include <cub/device/device_segmented_sort.cuh>
 
 #include <cassert>
-#include <numeric>
 
 namespace mlx::core {
@@ -27,29 +26,6 @@ struct ModOp {
   }
 };
 
-// We can not use any op in eval, make an utility.
-array swapaxes_in_eval(const array& in, int axis1, int axis2) {
-  std::vector<int> axes(in.ndim());
-  std::iota(axes.begin(), axes.end(), 0);
-  std::swap(axes[axis1], axes[axis2]);
-  // TODO: Share the code with Transpose::eval.
-  Shape shape(axes.size());
-  Strides strides(in.ndim());
-  for (size_t ax = 0; ax < axes.size(); ++ax) {
-    shape[ax] = in.shape()[axes[ax]];
-    strides[ax] = in.strides()[axes[ax]];
-  }
-  auto flags = in.flags();
-  if (flags.contiguous) {
-    auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
-    flags.row_contiguous = row_contiguous;
-    flags.col_contiguous = col_contiguous;
-  }
-  array out(shape, in.dtype(), nullptr, {});
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
-  return out;
-}
-
 struct OffsetTransform {
   int nsort;

View File

@@ -192,7 +192,7 @@ void ternary_op_gpu(
 }
 
 void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("select::eval_gpu");
+  nvtx3::scoped_range r("Select::eval_gpu");
   auto& s = out.primitive().stream();
   ternary_op_gpu<cu::Select>(inputs, out, s);
 }

View File

@@ -133,6 +133,7 @@ void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
+  MLX_PROFILER_RANGE("Pad::eval_gpu");
   // Inputs must be base input array and scalar val array
   assert(inputs.size() == 2);
   auto& in = inputs[0];

View File

@@ -18,21 +18,12 @@ cuda_skip = {
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
     "TestConv.test_conv_groups_grad",
-    "TestConv.test_torch_conv_1D_grad",
     "TestConv.test_torch_conv_2D",
-    "TestConv.test_torch_conv_2D_grad",
-    "TestConv.test_torch_conv_3D_grad",
     "TestConv.test_torch_conv_depthwise",
     "TestConv.test_torch_conv_general",
-    "TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
-    "TestConvTranspose.test_torch_conv_transpose_1D",
     "TestConvTranspose.test_torch_conv_transpose_1D_grad",
-    "TestConvTranspose.test_torch_conv_transpose_2D",
     "TestConvTranspose.test_torch_conv_transpose_2D_grad",
-    "TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
-    "TestConvTranspose.test_torch_conv_transpose_3D",
     "TestConvTranspose.test_torch_conv_transpose_3D_grad",
-    "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",