mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
8b25ce62d5
...
86258f292f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
86258f292f | ||
|
|
b26d88591c | ||
|
|
86c6a15571 |
@@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||
}
|
||||
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||
int ndim = x.ndim();
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
|
||||
auto shape = x.shape();
|
||||
std::swap(shape[axis1], shape[axis2]);
|
||||
auto strides = x.strides();
|
||||
std::swap(strides[axis1], strides[axis2]);
|
||||
|
||||
auto [data_size, row_contiguous, col_contiguous] =
|
||||
check_contiguity(shape, strides);
|
||||
bool contiguous = data_size == x.data_size();
|
||||
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(
|
||||
x,
|
||||
std::move(strides),
|
||||
{contiguous, row_contiguous, col_contiguous},
|
||||
x.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -196,6 +196,9 @@ void shared_buffer_reshape(
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
// Like the swapaxes op but safe to call in eval_gpu.
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
|
||||
@@ -104,10 +104,41 @@ struct FusedKernelBuilder {
|
||||
" }\n";
|
||||
}
|
||||
|
||||
// Vectorized read loop
|
||||
if (contiguous) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
os += fmt::format(
|
||||
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
|
||||
xname,
|
||||
type);
|
||||
}
|
||||
}
|
||||
|
||||
// Create some space for the outputs
|
||||
for (const auto& x : outputs) {
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
os += fmt::format(
|
||||
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
|
||||
}
|
||||
|
||||
// Work loop
|
||||
os +=
|
||||
"\n"
|
||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||
if (!contiguous) {
|
||||
os +=
|
||||
"\n"
|
||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||
} else {
|
||||
os +=
|
||||
"\n"
|
||||
" #pragma unroll\n"
|
||||
" for (int i = 0; i < work_per_thread; i++) {\n";
|
||||
}
|
||||
|
||||
// Read inputs.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
@@ -122,7 +153,7 @@ struct FusedKernelBuilder {
|
||||
} else if (is_scalar(x)) {
|
||||
value = fmt::format("{}[0]", xname);
|
||||
} else if (contiguous) {
|
||||
value = fmt::format("{}[index]", xname);
|
||||
value = fmt::format("vec_{}[i]", xname);
|
||||
} else {
|
||||
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||
}
|
||||
@@ -150,25 +181,30 @@ struct FusedKernelBuilder {
|
||||
|
||||
// Write output.
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
|
||||
// End of work loop
|
||||
os +=
|
||||
"\n"
|
||||
" index++;\n";
|
||||
if (!contiguous) {
|
||||
os += "\n";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
|
||||
os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
|
||||
}
|
||||
}
|
||||
os += " }\n";
|
||||
|
||||
// Store the output to global memory
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(
|
||||
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
|
||||
namer.get_name(x));
|
||||
}
|
||||
|
||||
os += "}\n";
|
||||
}
|
||||
};
|
||||
@@ -192,6 +228,15 @@ void Compiled::eval_gpu(
|
||||
nvtx3::scoped_range r("Compiled::eval_gpu");
|
||||
auto& s = stream();
|
||||
|
||||
// Determine the work per thread for the vectorized reads/writes. We take it
|
||||
// as 16 over the max itemsize for the outputs. Another heuristic could be
|
||||
// over the max itemsize of all arrays.
|
||||
int max_size = 1;
|
||||
for (const auto& x : outputs) {
|
||||
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
|
||||
}
|
||||
int work_per_thread = 16 / max_size;
|
||||
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
||||
// Build source code.
|
||||
cu::FusedKernelBuilder builder{
|
||||
@@ -205,28 +250,23 @@ void Compiled::eval_gpu(
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names;
|
||||
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
|
||||
lib_name(),
|
||||
i,
|
||||
work_per_thread));
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
|
||||
lib_name(),
|
||||
i,
|
||||
work_per_thread));
|
||||
"mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
|
||||
@@ -269,7 +309,6 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
|
||||
// Choose work per thread
|
||||
int work_per_thread = 4;
|
||||
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -25,25 +24,34 @@ namespace {
|
||||
// Not all engines support it so can not use this API now.
|
||||
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||
|
||||
// Alias for better readability.
|
||||
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
|
||||
#define CONV_BACKWARD_INPUT \
|
||||
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
|
||||
#define CONV_BACKWARD_WEIGHT \
|
||||
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
||||
|
||||
struct ConvCacheKey {
|
||||
int device_id;
|
||||
cudnnBackendDescriptorType_t backend_type;
|
||||
cudnnDataType_t cudnn_type;
|
||||
cudnnDataType_t cudnn_dtype;
|
||||
std::array<int, MAX_NDIM> input_shape;
|
||||
std::array<int, MAX_NDIM> filter_shape;
|
||||
std::array<int, MAX_NDIM> weight_shape;
|
||||
std::array<int, MAX_NDIM> stride;
|
||||
std::array<int, MAX_NDIM> padding_lo;
|
||||
std::array<int, MAX_NDIM> padding_hi;
|
||||
std::array<int, MAX_NDIM> stride;
|
||||
std::array<int, MAX_NDIM> dilation;
|
||||
int groups;
|
||||
bool flip;
|
||||
uint8_t input_alignment;
|
||||
uint8_t filter_alignment;
|
||||
uint8_t weight_alignment;
|
||||
uint8_t output_alignment;
|
||||
};
|
||||
|
||||
auto& conv_cache() {
|
||||
static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
|
||||
/* capacity */ 128);
|
||||
static LRUBytesKeyCache<
|
||||
ConvCacheKey,
|
||||
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
|
||||
cache(/* capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
@@ -162,17 +170,17 @@ cudnn_frontend::EngineConfigList get_engine_configs(
|
||||
bool execute_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out) {
|
||||
array& x,
|
||||
array& w,
|
||||
array& y) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
||||
|
||||
int64_t uids[3] = {'x', 'w', 'y'};
|
||||
void* data_ptrs[3] = {
|
||||
const_cast<void*>(in.data<void>()),
|
||||
const_cast<void*>(wt.data<void>()),
|
||||
out.data<void>(),
|
||||
x.data<void>(),
|
||||
w.data<void>(),
|
||||
y.data<void>(),
|
||||
};
|
||||
|
||||
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
||||
@@ -212,46 +220,154 @@ bool execute_plan(
|
||||
|
||||
bool try_engines(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::EngineConfigList& configs,
|
||||
const ConvCacheKey& cache_key,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
cudnn_frontend::EngineConfigList& configs,
|
||||
const std::string& op_graph_tag,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out) {
|
||||
array& x,
|
||||
array& w,
|
||||
array& y) {
|
||||
for (auto& config : configs) {
|
||||
try {
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
if (execute_plan(encoder, plan, in, wt, out)) {
|
||||
conv_cache().emplace(cache_key, std::move(plan));
|
||||
if (execute_plan(encoder, plan, x, w, y)) {
|
||||
conv_cache().emplace(
|
||||
cache_key, std::make_pair(backend_type, std::move(plan)));
|
||||
return true;
|
||||
}
|
||||
} catch (cudnn_frontend::cudnnException&) {
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
auto get_conv_op_settings(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y,
|
||||
const std::vector<int>& kernel_strides,
|
||||
const std::vector<int>& padding_lo_,
|
||||
const std::vector<int>& padding_hi_,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation) {
|
||||
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
||||
auto padding_hi = convert_vector<int64_t>(padding_hi_);
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||
for (int i = 0; i < padding_lo.size(); ++i) {
|
||||
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
|
||||
padding_lo[i] = wt_size - padding_lo[i] - 1;
|
||||
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
|
||||
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
|
||||
padding_hi[i] = out_size - in_size + padding_hi[i];
|
||||
}
|
||||
return std::make_tuple(
|
||||
convert_vector<int64_t>(input_dilation),
|
||||
std::move(padding_lo),
|
||||
std::move(padding_hi),
|
||||
convert_vector<int64_t>(kernel_dilation));
|
||||
|
||||
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
padding_hi = padding_lo;
|
||||
return std::make_tuple(
|
||||
convert_vector<int64_t>(kernel_dilation),
|
||||
std::move(padding_lo),
|
||||
std::move(padding_hi),
|
||||
convert_vector<int64_t>(kernel_strides));
|
||||
|
||||
} else {
|
||||
return std::make_tuple(
|
||||
convert_vector<int64_t>(kernel_strides),
|
||||
std::move(padding_lo),
|
||||
std::move(padding_hi),
|
||||
convert_vector<int64_t>(kernel_dilation));
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y,
|
||||
const std::vector<int64_t>& stride,
|
||||
const std::vector<int64_t>& padding_lo,
|
||||
const std::vector<int64_t>& padding_hi,
|
||||
const std::vector<int64_t>& dilation) {
|
||||
try {
|
||||
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
||||
? CUDNN_DATA_FLOAT
|
||||
: dtype_to_cudnn_type(dtype);
|
||||
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
||||
.setDataType(compute_dtype)
|
||||
.setMathMode(CUDNN_CROSS_CORRELATION)
|
||||
.setNDims(stride.size())
|
||||
.setStrides(stride.size(), stride.data())
|
||||
.setPrePadding(padding_lo.size(), padding_lo.data())
|
||||
.setPostPadding(padding_hi.size(), padding_hi.data())
|
||||
.setDilation(dilation.size(), dilation.data())
|
||||
.build();
|
||||
|
||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||
.setxDesc(build_tensor('x', x))
|
||||
.setwDesc(build_tensor('w', w))
|
||||
.setyDesc(build_tensor('y', y))
|
||||
.setcDesc(conv_desc)
|
||||
.build();
|
||||
|
||||
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
||||
return cudnn_frontend::OperationGraphBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setOperationGraph(ops.size(), ops.data())
|
||||
.build();
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
|
||||
throw;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// Do necessary transposes and copies to prepare the inputs and outputs for
|
||||
// building the cuDNN conv op. It is safe to be called multiple times in one
|
||||
// eval_gpu, with cost of possible redundant copies.
|
||||
std::tuple<array, array, array> prepare_args(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array in,
|
||||
array wt,
|
||||
array out,
|
||||
Stream s) {
|
||||
// Transpose the args depending on the backend type.
|
||||
// TODO: Handle groups.
|
||||
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
in = swapaxes_in_eval(in, 0, -1);
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
// Create a contiguous array that shares the data with |out|, but with dim
|
||||
// C_in and C_out swapped.
|
||||
Shape shape(out.shape());
|
||||
std::swap(shape.front(), shape.back());
|
||||
Strides strides(shape.size(), 1);
|
||||
for (int i = shape.size() - 2; i >= 0; --i) {
|
||||
strides[i] = shape[i + 1] * strides[i + 1];
|
||||
}
|
||||
array intermediate(std::move(shape), out.dtype(), nullptr, {});
|
||||
intermediate.copy_shared_buffer(
|
||||
out, std::move(strides), {true, true, false}, out.data_size());
|
||||
out = intermediate;
|
||||
}
|
||||
|
||||
assert(inputs.size() == 2);
|
||||
array in = inputs[0];
|
||||
array wt = inputs[1];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// cuDNN requires contiguous input.
|
||||
// TODO: Handle NCHW format specially.
|
||||
if (!in.flags().row_contiguous) {
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
encoder.add_temporary(in);
|
||||
@@ -261,80 +377,170 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.add_temporary(wt);
|
||||
}
|
||||
|
||||
return {std::move(in), std::move(wt), std::move(out)};
|
||||
}
|
||||
|
||||
// Get the x/w/y args from the in/wt/out args depending on backend type.
|
||||
inline std::tuple<array&, array&, array&> dispatch_args(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& in,
|
||||
array& wt,
|
||||
array& out) {
|
||||
switch (backend_type) {
|
||||
case CONV_BACKWARD_INPUT:
|
||||
return {out, wt, in};
|
||||
case CONV_BACKWARD_WEIGHT:
|
||||
return {in, out, wt};
|
||||
default:
|
||||
return {in, wt, out};
|
||||
}
|
||||
}
|
||||
|
||||
// Register inputs and outputs before actually running conv op. Can only be
|
||||
// called once per eval_gpu.
|
||||
void register_args(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& in,
|
||||
array& wt,
|
||||
array& intermediate_out,
|
||||
array& final_out) {
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(wt);
|
||||
encoder.set_output_array(out);
|
||||
encoder.set_output_array(final_out);
|
||||
|
||||
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
|
||||
auto cudnn_type = dtype_to_cudnn_type(in.dtype());
|
||||
if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
// Turn |out| into a strided array, which will have C_in and C_out swapped
|
||||
// in vjp and the final |grad_weight| will then be contiguous.
|
||||
Strides strides = intermediate_out.strides();
|
||||
std::swap(strides.front(), strides.back());
|
||||
final_out.copy_shared_buffer(
|
||||
intermediate_out,
|
||||
std::move(strides),
|
||||
{false, false, false},
|
||||
intermediate_out.data_size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||
if (out_.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(inputs.size() == 2);
|
||||
array in = inputs[0];
|
||||
array wt = inputs[1];
|
||||
array out = out_;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
Dtype dtype = out.dtype();
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Search cache.
|
||||
ConvCacheKey cache_key{
|
||||
encoder.device().cuda_device(),
|
||||
backend_type,
|
||||
cudnn_type,
|
||||
dtype_to_cudnn_type(dtype),
|
||||
fixed_vector(in.shape()),
|
||||
fixed_vector(wt.shape()),
|
||||
fixed_vector(kernel_strides_),
|
||||
fixed_vector(padding_lo_),
|
||||
fixed_vector(padding_hi_),
|
||||
fixed_vector(kernel_strides_),
|
||||
fixed_vector(kernel_dilation_),
|
||||
groups_,
|
||||
flip_,
|
||||
get_alignment(in),
|
||||
get_alignment(wt),
|
||||
get_alignment(out)};
|
||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||
if (!execute_plan(encoder, it->second, in, wt, out)) {
|
||||
throw std::runtime_error("Cached convolution plan failed to execute.");
|
||||
auto& [backend_type, plan] = it->second;
|
||||
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (!execute_plan(encoder, plan, x, w, y)) {
|
||||
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Build operation graph.
|
||||
auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
|
||||
? CUDNN_DATA_FLOAT
|
||||
: cudnn_type;
|
||||
// There is no reliable way to deduce the proper cuDNN backend for the
|
||||
// convolution, so we make a best guess and then try.
|
||||
std::vector<cudnnBackendDescriptorType_t> try_backends;
|
||||
if (flip_) {
|
||||
// When weight is flipped, we assume it is backward input convolution.
|
||||
try_backends.push_back(CONV_BACKWARD_INPUT);
|
||||
} else {
|
||||
// Otherwise it could be backward weight convolution or forward convolution,
|
||||
// mathematically there is no difference so we have to use heuristics.
|
||||
// Empirically backward convolutions have large kernel dimensions, and
|
||||
// usually have |in| and |wt| transposed.
|
||||
if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
|
||||
wt.shape(2) > out.shape(2)) {
|
||||
try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
|
||||
} else {
|
||||
try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
|
||||
}
|
||||
}
|
||||
|
||||
auto stride = convert_vector<int64_t>(kernel_strides_);
|
||||
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
||||
auto padding_hi = convert_vector<int64_t>(padding_hi_);
|
||||
auto dilation = convert_vector<int64_t>(kernel_dilation_);
|
||||
// Try to build op graph.
|
||||
cudnnBackendDescriptorType_t backend_type;
|
||||
std::optional<cudnn_frontend::OperationGraph> op_graph;
|
||||
for (auto try_backend : try_backends) {
|
||||
auto [in_copy, wt_copy, out_copy] =
|
||||
prepare_args(encoder, try_backend, in, wt, out, s);
|
||||
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
|
||||
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
|
||||
try_backend,
|
||||
x,
|
||||
w,
|
||||
y,
|
||||
kernel_strides_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_dilation_,
|
||||
input_dilation_);
|
||||
op_graph = build_op_graph(
|
||||
encoder,
|
||||
try_backend,
|
||||
dtype,
|
||||
x,
|
||||
w,
|
||||
y,
|
||||
stride,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
dilation);
|
||||
if (op_graph) {
|
||||
backend_type = try_backend;
|
||||
in = std::move(in_copy);
|
||||
wt = std::move(wt_copy);
|
||||
out = std::move(out_copy);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!op_graph) {
|
||||
throw std::runtime_error("[conv] Can not build op graph.");
|
||||
}
|
||||
|
||||
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
||||
.setDataType(compute_data_type)
|
||||
.setMathMode(CUDNN_CROSS_CORRELATION)
|
||||
.setNDims(stride.size())
|
||||
.setStrides(stride.size(), stride.data())
|
||||
.setPrePadding(padding_lo.size(), padding_lo.data())
|
||||
.setPostPadding(padding_hi.size(), padding_hi.data())
|
||||
.setDilation(dilation.size(), dilation.data())
|
||||
.build();
|
||||
|
||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||
.setxDesc(build_tensor('x', in))
|
||||
.setwDesc(build_tensor('w', wt))
|
||||
.setyDesc(build_tensor('y', out))
|
||||
.setcDesc(conv_desc)
|
||||
.build();
|
||||
|
||||
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
||||
auto op_graph = cudnn_frontend::OperationGraphBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setOperationGraph(ops.size(), ops.data())
|
||||
.build();
|
||||
// Get ready to execute the graph.
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
|
||||
// Try to run plans based on heuristics.
|
||||
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||
auto op_graph_tag = op_graph.getTag();
|
||||
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
||||
auto configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||
auto tag = op_graph->getTag();
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||
return;
|
||||
}
|
||||
// Then try fallback plans.
|
||||
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
||||
configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("Unable to find an engine for convolution.");
|
||||
throw std::runtime_error("[conv] Unable to find a working engine.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -269,6 +269,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
nvtx3::scoped_range r("CommandEncoder::commit");
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
|
||||
@@ -36,18 +36,15 @@ void eval(array& arr) {
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
// Except for the donated one.
|
||||
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
encoder.add_temporary(s);
|
||||
}
|
||||
// Remove the output if it was donated to by an input.
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
||||
encoder.maybe_commit();
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
@@ -42,6 +44,7 @@ inline array ensure_row_contiguous_matrix(
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& d = cu::device(s.device);
|
||||
auto& enc = d.get_command_encoder(s);
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
#include <cub/device/device_segmented_sort.cuh>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -27,29 +26,6 @@ struct ModOp {
|
||||
}
|
||||
};
|
||||
|
||||
// We can not use any op in eval, make an utility.
|
||||
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
|
||||
std::vector<int> axes(in.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::swap(axes[axis1], axes[axis2]);
|
||||
// TODO: Share the code with Transpose::eval.
|
||||
Shape shape(axes.size());
|
||||
Strides strides(in.ndim());
|
||||
for (size_t ax = 0; ax < axes.size(); ++ax) {
|
||||
shape[ax] = in.shape()[axes[ax]];
|
||||
strides[ax] = in.strides()[axes[ax]];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous) {
|
||||
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
}
|
||||
array out(shape, in.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
struct OffsetTransform {
|
||||
int nsort;
|
||||
|
||||
|
||||
@@ -192,7 +192,7 @@ void ternary_op_gpu(
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("select::eval_gpu");
|
||||
nvtx3::scoped_range r("Select::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
ternary_op_gpu<cu::Select>(inputs, out, s);
|
||||
}
|
||||
|
||||
@@ -133,6 +133,7 @@ void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Pad::eval_gpu");
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@@ -18,21 +18,12 @@ cuda_skip = {
|
||||
"TestConv.test_conv_1d_groups_flipped",
|
||||
"TestConv.test_conv_general_flip_grad",
|
||||
"TestConv.test_conv_groups_grad",
|
||||
"TestConv.test_torch_conv_1D_grad",
|
||||
"TestConv.test_torch_conv_2D",
|
||||
"TestConv.test_torch_conv_2D_grad",
|
||||
"TestConv.test_torch_conv_3D_grad",
|
||||
"TestConv.test_torch_conv_depthwise",
|
||||
"TestConv.test_torch_conv_general",
|
||||
"TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
|
||||
"TestConvTranspose.test_torch_conv_transpose_1D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||
# FFTs NYI
|
||||
"TestFFT.test_fft",
|
||||
"TestFFT.test_fft_big_powers_of_two",
|
||||
|
||||
Reference in New Issue
Block a user