Implement forward rms_norm with cuDNN
mlx/backend/cuda/CMakeLists.txt
@@ -17,6 +17,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
@@ -36,7 +37,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
mlx/backend/cuda/conv.cpp
@@ -1,15 +1,11 @@
 // Copyright © 2025 Apple Inc.
 
+#include "mlx/backend/cuda/cudnn_utils.h"
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/config.h"
 #include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
-#include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
 
-#include <cudnn_frontend.h>
-#include <cudnn_frontend_find_plan.h>
-#include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
 
 #include <cassert>
@@ -18,9 +14,6 @@ namespace mlx::core {
 
 namespace {
 
-// Not all engines support it so can not use this API now.
-#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
-
 // Alias for better readability.
 #define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
 #define CONV_BACKWARD_INPUT \
@@ -52,195 +45,6 @@ auto& conv_cache() {
   return cache;
 }
 
-template <typename T, typename Vec>
-inline SmallVector<T> convert_vector(const Vec& vec) {
-  return SmallVector<T>(vec.begin(), vec.end());
-}
-
-template <typename T, template <typename U> class Vec>
-inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
-  if (vec.size() > MAX_NDIM) {
-    throw std::runtime_error(
-        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
-  }
-  std::array<T, MAX_NDIM> result = {};
-  std::copy_n(vec.begin(), vec.size(), result.begin());
-  return result;
-}
-
-auto nhwc_to_nchw(const array& x) {
-  auto shape = convert_vector<int64_t>(x.shape());
-  shape.insert(shape.begin() + 1, shape.back());
-  shape.erase(shape.end() - 1);
-  auto strides = convert_vector<int64_t>(x.strides());
-  strides.insert(strides.begin() + 1, strides.back());
-  strides.erase(strides.end() - 1);
-  return std::make_tuple(std::move(shape), std::move(strides));
-}
-
-inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
-  switch (dtype) {
-    case int8:
-      return CUDNN_DATA_INT8;
-    case int32:
-      return CUDNN_DATA_INT32;
-    case uint8:
-      return CUDNN_DATA_UINT8;
-    case float16:
-      return CUDNN_DATA_HALF;
-    case bfloat16:
-      return CUDNN_DATA_BFLOAT16;
-    case float32:
-      return CUDNN_DATA_FLOAT;
-    case float64:
-      return CUDNN_DATA_DOUBLE;
-    default:
-      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
-  }
-}
-
-inline uint8_t get_alignment(const array& x) {
-  uint8_t alignment = 1;
-  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
-  for (; alignment < 32; alignment *= 2) {
-    if (address % (alignment * 2)) {
-      return alignment;
-    }
-  }
-  return alignment;
-}
-
-inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
-  auto [shape, strides] = nhwc_to_nchw(x);
-  return cudnn_frontend::TensorBuilder()
-      .setDim(shape.size(), shape.data())
-      .setStrides(strides.size(), strides.data())
-      .setId(id)
-      .setAlignment(get_alignment(x))
-      .setDataType(dtype_to_cudnn_type(x.dtype()))
-      .build();
-}
-
-cudnn_frontend::EngineConfigList get_engine_configs(
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph,
-    bool use_fallback = false) {
-  cudnn_frontend::GeneratorSource source;
-  if (use_fallback) {
-    source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
-      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
-                          .setOperationGraph(op_graph)
-                          .setOperation(backend_type)
-                          .build();
-      return fallback.getFallbackList();
-    };
-  } else {
-    source = [](cudnn_frontend::OperationGraph& op_graph) {
-      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
-                            .setOperationGraph(op_graph)
-                            .setHeurMode(CUDNN_HEUR_MODE_A)
-                            .build();
-      return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
-    };
-  }
-
-  cudnn_frontend::EngineConfigGenerator generator(1, &source);
-  auto configs = generator.generate_engine_config(op_graph);
-
-  cudnn_frontend::EngineConfigList filtered_configs;
-  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
-    if (cudnn_frontend::hasNumericalNote<
-            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
-      return true;
-    }
-    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
-        dtype == float32 && !env::enable_tf32()) {
-      return true;
-    }
-    return false;
-  });
-  return filtered_configs;
-}
-
-bool execute_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    array& x,
-    array& w,
-    array& y) {
-  int workspace_size = plan.getWorkspaceSize();
-  array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
-
-  int64_t uids[3] = {'x', 'w', 'y'};
-  void* data_ptrs[3] = {
-      x.data<void>(),
-      w.data<void>(),
-      y.data<void>(),
-  };
-
-  auto variantPack = cudnn_frontend::VariantPackBuilder()
-                         .setWorkspacePointer(workspace.data<void>())
-                         .setDataPointers(3, data_ptrs)
-                         .setUids(3, uids)
-                         .build();
-
-  auto handle = encoder.device().cudnn_handle();
-  cudnnSetStream(handle, encoder.stream());
-
-#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
-  CudaGraph graph(encoder.device());
-  if (cudnnBackendPopulateCudaGraph(
-          handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
-      CUDNN_STATUS_SUCCESS) {
-    return false;
-  }
-  encoder.add_graph_node(graph);
-#else
-  auto capture = encoder.capture_context();
-  if (cudnnBackendExecute(
-          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
-      CUDNN_STATUS_SUCCESS) {
-    // Discard the captured graph when failed.
-    capture.discard = true;
-    return false;
-  }
-#endif
-
-  encoder.add_temporary(workspace);
-  return true;
-}
-
-bool try_engines(
-    cu::CommandEncoder& encoder,
-    const ConvCacheKey& cache_key,
-    cudnnBackendDescriptorType_t backend_type,
-    cudnn_frontend::EngineConfigList& configs,
-    const std::string& op_graph_tag,
-    array& x,
-    array& w,
-    array& y) {
-  for (auto& config : configs) {
-    try {
-      auto plan = cudnn_frontend::ExecutionPlanBuilder()
-                      .setHandle(encoder.device().cudnn_handle())
-                      .setEngineConfig(config, op_graph_tag)
-                      .build();
-      if (execute_plan(encoder, plan, x, w, y)) {
-        conv_cache().emplace(
-            cache_key, std::make_pair(backend_type, std::move(plan)));
-        return true;
-      }
-    } catch (cudnn_frontend::cudnnException& error) {
-      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
-        throw;
-      }
-    }
-  }
-  return false;
-}
-
 auto get_conv_op_settings(
     cudnnBackendDescriptorType_t backend_type,
     array& x,
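The get_alignment helper deleted above moves unchanged into cudnn_utils.h (shown later in this commit). For intuition, it returns the largest power of two, capped at 32, that divides the data pointer's address; cuDNN uses this when choosing vectorized kernels. A minimal standalone sketch of the same loop (alignment_of is a hypothetical name for illustration, not part of the commit):

#include <cstdint>
#include <cstdio>

// Return the largest power-of-two divisor of |address|, capped at 32.
static uint8_t alignment_of(uintptr_t address) {
  uint8_t alignment = 1;
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}

int main() {
  std::printf(
      "%u %u %u\n",
      (unsigned)alignment_of(0x1004),   // 4
      (unsigned)alignment_of(0x1010),   // 16
      (unsigned)alignment_of(0x2000));  // 32
}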
@@ -285,7 +89,7 @@ auto get_conv_op_settings(
   }
 }
 
-std::optional<cudnn_frontend::OperationGraph> build_op_graph(
+std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
     cu::CommandEncoder& encoder,
     cudnnBackendDescriptorType_t backend_type,
     Dtype dtype,
@@ -311,9 +115,9 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
                        .build();
 
   auto op = cudnn_frontend::OperationBuilder(backend_type)
-                .setxDesc(build_tensor('x', x))
-                .setwDesc(build_tensor('w', w))
-                .setyDesc(build_tensor('y', y))
+                .setxDesc(build_cudnn_tensor_nchw('x', x))
+                .setwDesc(build_cudnn_tensor_nchw('w', w))
+                .setyDesc(build_cudnn_tensor_nchw('y', y))
                 .setcDesc(conv_desc)
                 .build();
 
@@ -475,12 +279,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   ConvCacheKey cache_key{
       encoder.device().cuda_device(),
       dtype_to_cudnn_type(dtype),
-      fixed_vector(in.shape()),
-      fixed_vector(wt.shape()),
-      fixed_vector(kernel_strides_),
-      fixed_vector(padding_lo_),
-      fixed_vector(padding_hi_),
-      fixed_vector(kernel_dilation_),
+      vector_key(in.shape()),
+      vector_key(wt.shape()),
+      vector_key(kernel_strides_),
+      vector_key(padding_lo_),
+      vector_key(padding_hi_),
+      vector_key(kernel_dilation_),
       groups_,
       flip_,
       get_alignment(in),
@@ -492,7 +296,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
     prepare_args(encoder, backend_type, in, wt, out, groups_, s);
     register_args(encoder, backend_type, in, wt, out, out_);
     auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-    if (!execute_plan(encoder, plan, x, w, y)) {
+    if (!encode_cudnn_plan(encoder, plan, {'x', 'w', 'y'}, x, w, y)) {
       throw std::runtime_error("[conv] Cached plan failed to execute.");
     }
     return;
@@ -534,7 +338,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
       padding_hi_,
       kernel_dilation_,
       input_dilation_);
-  op_graph = build_op_graph(
+  op_graph = build_conv_op_graph(
       encoder,
       try_backend,
       dtype,
@@ -557,22 +361,21 @@
     throw std::runtime_error("[conv] Can not build op graph.");
   }
 
-  // Get ready to execute the graph.
+  // Setup inputs and outputs.
   register_args(encoder, backend_type, in, wt, out, out_);
 
-  // Try to run plans based on heuristics.
-  auto configs = get_engine_configs(backend_type, dtype, *op_graph);
-  auto tag = op_graph->getTag();
+  // Find a plan for the graph and execute it.
+  auto plan = find_cudnn_plan_from_op_graph(
+      encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
+  if (!plan) {
+    throw std::runtime_error("[conv] Unable to find an execution plan.");
+  }
   auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
-    return;
-  }
-  // Then try fallback plans.
-  configs = get_engine_configs(backend_type, dtype, *op_graph);
-  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
-    return;
-  }
-  throw std::runtime_error("[conv] Unable to find a working engine.");
+  if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
+    throw std::runtime_error("[conv] Failed to run execution plan.");
+  }
+  conv_cache().emplace(
+      cache_key, std::make_pair(backend_type, std::move(*plan)));
 }
 
 } // namespace mlx::core
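Taken together, the conv.cpp changes replace the two-pass retry loop (try_engines over heuristic configs, then again over fallback configs) with a single find-then-encode sequence; get_cudnn_engine_configs in the new cudnn_utils.cpp appends the fallback list after the heuristic configs, so one pass covers both, and the plan is cached only after a successful encode. A condensed sketch of the resulting control flow, using the commit's own helpers (run_conv_tail is a hypothetical wrapper for illustration, not a function in the diff, and this fragment assumes the surrounding mlx::core context):

// Hypothetical condensation of the new Convolution::eval_gpu tail.
void run_conv_tail(
    cu::CommandEncoder& encoder,
    cudnnBackendDescriptorType_t backend_type,
    Dtype dtype,
    cudnn_frontend::OperationGraph& op_graph,
    const ConvCacheKey& cache_key,
    array& x,
    array& w,
    array& y) {
  // One search covers heuristics and fallback configs.
  auto plan = find_cudnn_plan_from_op_graph(
      encoder.device().cudnn_handle(), backend_type, dtype, op_graph);
  if (!plan) {
    throw std::runtime_error("[conv] Unable to find an execution plan.");
  }
  if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
    throw std::runtime_error("[conv] Failed to run execution plan.");
  }
  // Cache only after the plan has actually been encoded successfully.
  conv_cache().emplace(
      cache_key, std::make_pair(backend_type, std::move(*plan)));
}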
mlx/backend/cuda/cudnn_utils.cpp (new file, 252 lines)
@@ -0,0 +1,252 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cudnn_utils.h"
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core {
+
+namespace {
+
+// Create a cudnn tensor descriptor.
+template <typename Vec>
+inline cudnn_frontend::Tensor build_cudnn_tensor(
+    int64_t id,
+    const array& x,
+    const Vec& shape,
+    const Vec& strides) {
+  return cudnn_frontend::TensorBuilder()
+      .setDim(shape.size(), shape.data())
+      .setStrides(strides.size(), strides.data())
+      .setId(id)
+      .setAlignment(get_alignment(x))
+      .setDataType(dtype_to_cudnn_type(x.dtype()))
+      .build();
+}
+
+// Return the shape and strides after transposing from NHWC to NCHW.
+auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
+  assert(shape.size() >= 3);
+  shape.insert(shape.begin() + 1, shape.back());
+  shape.erase(shape.end() - 1);
+  strides.insert(strides.begin() + 1, strides.back());
+  strides.erase(strides.end() - 1);
+  return std::make_tuple(std::move(shape), std::move(strides));
+}
+
+auto nhwc_to_nchw(const array& x) {
+  return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
+}
+
+// Return available engines for |op_graph|.
+cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph,
+    bool use_fallback = true) {
+  SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
+  sources.push_back([](auto& op_graph) {
+    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                          .setOperationGraph(op_graph)
+                          .setHeurMode(CUDNN_HEUR_MODE_A)
+                          .build();
+    return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
+  });
+  if (use_fallback) {
+    sources.push_back([&backend_type](auto& op_graph) {
+      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+                          .setOperationGraph(op_graph)
+                          .setOperation(backend_type)
+                          .build();
+      return fallback.getFallbackList();
+    });
+  }
+
+  auto configs =
+      cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
+          .generate_engine_config(op_graph);
+
+  cudnn_frontend::EngineConfigList filtered_configs;
+  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
+    if (cudnn_frontend::hasNumericalNote<
+            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
+      return true;
+    }
+    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
+        dtype == float32 && !env::enable_tf32()) {
+      return true;
+    }
+    return false;
+  });
+  return filtered_configs;
+}
+
+// Take |engine_configs| and |op_graph| and find a working execution plan
+// from them.
+std::optional<cudnn_frontend::ExecutionPlan>
+find_cudnn_plan_from_engine_configs(
+    cudnnHandle_t handle,
+    const cudnn_frontend::EngineConfigList& engine_configs,
+    const cudnn_frontend::OperationGraph& op_graph) {
+  auto op_graph_tag = op_graph.getTag();
+  for (const auto& config : engine_configs) {
+    try {
+      return cudnn_frontend::ExecutionPlanBuilder()
+          .setHandle(handle)
+          .setEngineConfig(config, op_graph_tag)
+          .build();
+    } catch (cudnn_frontend::cudnnException& error) {
+      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
+        throw;
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+// Prepare workspace and args to execute plan.
+template <typename F>
+bool prepare_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs,
+    F&& execute) {
+  int workspace_size = plan.getWorkspaceSize();
+  array workspace(
+      workspace_size > 0 ? allocator::malloc(workspace_size)
+                         : allocator::Buffer(nullptr),
+      {workspace_size},
+      uint8);
+
+  auto args = cudnn_frontend::VariantPackBuilder()
+                  .setWorkspacePointer(workspace.data<void>())
+                  .setDataPointers(num_args, data_ptrs)
+                  .setUids(num_args, uids)
+                  .build();
+
+  auto handle = encoder.device().cudnn_handle();
+  cudnnSetStream(handle, encoder.stream());
+
+  if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
+    return false;
+  }
+
+  encoder.add_temporary(workspace);
+  return true;
+}
+
+} // namespace
+
+cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
+  auto shape = convert_vector<int64_t>(x.shape());
+  return build_cudnn_tensor(id, x, shape, x.strides());
+}
+
+cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
+  auto [shape, strides] = nhwc_to_nchw(x);
+  return build_cudnn_tensor(id, x, shape, strides);
+}
+
+cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
+  if (x.ndim() == 0) {
+    SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
+    return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
+  }
+  if (x.ndim() == 1) {
+    int64_t s = x.shape(0);
+    SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
+    SmallVector<int64_t, 4> strides = {s, 1, s, s};
+    return build_cudnn_tensor(id, x, shape, strides);
+  }
+  if (x.ndim() == 2) {
+    int64_t s = x.strides(0);
+    SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
+    SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
+    return build_cudnn_tensor(id, x, shape, strides);
+  }
+  if (x.ndim() == 3 || x.ndim() == 4) {
+    return build_cudnn_tensor_nchw(id, x);
+  }
+  throw std::runtime_error(
+      fmt::format("Unsupported array with {} dims.", x.ndim()));
+}
+
+cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
+  SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
+  return cudnn_frontend::TensorBuilder()
+      .setDim(scalar_dims.size(), scalar_dims.data())
+      .setStrides(scalar_dims.size(), scalar_dims.data())
+      .setId(id)
+      .setAlignment(16)
+      .setDataType(dtype_to_cudnn_type(dtype))
+      .setByValue(true)
+      .build();
+}
+
+std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
+    cudnnHandle_t handle,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph) {
+  auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
+  return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
+}
+
+bool encode_cudnn_plan_with_capturing(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs) {
+  return prepare_cudnn_plan(
+      encoder,
+      plan,
+      num_args,
+      uids,
+      data_ptrs,
+      [&](auto handle, auto plan, auto args) {
+        auto capture = encoder.capture_context();
+        if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
+          // Discard the captured graph when failed.
+          capture.discard = true;
+          return false;
+        }
+        return true;
+      });
+}
+
+#if CUDNN_VERSION >= 90500
+bool encode_cudnn_plan_with_graph_api(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs) {
+  return prepare_cudnn_plan(
+      encoder,
+      plan,
+      num_args,
+      uids,
+      data_ptrs,
+      [&](auto handle, auto plan, auto args) {
+        if (!graph) {
+          graph = CudaGraph(encoder.device());
+          if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
+              CUDNN_STATUS_SUCCESS) {
+            return false;
+          }
+        } else {
+          if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
+              CUDNN_STATUS_SUCCESS) {
+            return false;
+          }
+        }
+        encoder.add_graph_node(graph);
+        return true;
+      });
+}
+#endif
+
+} // namespace mlx::core
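A note on nhwc_to_nchw above: it permutes only metadata (shape and strides); no data is copied, so cuDNN receives an NCHW view of the same NHWC buffer. A self-contained sketch of the permutation using plain std::vector instead of SmallVector, with an assumed 8x32x32x64 row-major NHWC array for illustration:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> shape = {8, 32, 32, 64};         // N, H, W, C
  std::vector<int64_t> strides = {65536, 2048, 64, 1};  // row-major NHWC
  // Move the channel (last) entry to position 1, as nhwc_to_nchw does.
  auto perm = [](std::vector<int64_t>& v) {
    v.insert(v.begin() + 1, v.back());
    v.erase(v.end() - 1);
  };
  perm(shape);
  perm(strides);
  assert((shape == std::vector<int64_t>{8, 64, 32, 32}));      // N, C, H, W
  assert((strides == std::vector<int64_t>{65536, 1, 2048, 64}));
}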
mlx/backend/cuda/cudnn_utils.h (new file, 164 lines)
@@ -0,0 +1,164 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/utils.h"
+#include "mlx/dtype_utils.h"
+
+#include <cudnn_frontend.h>
+#include <cudnn_frontend_find_plan.h>
+#include <fmt/format.h>
+
+#include <algorithm>
+#include <array>
+
+namespace mlx::core {
+
+namespace cu {
+class CommandEncoder;
+}
+
+// Return pointer alignment of |x|'s data.
+inline uint8_t get_alignment(const array& x) {
+  uint8_t alignment = 1;
+  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
+  for (; alignment < 32; alignment *= 2) {
+    if (address % (alignment * 2)) {
+      return alignment;
+    }
+  }
+  return alignment;
+}
+
+// Convert the type of elements in |vec| to |T|.
+template <typename T, typename Vec>
+inline SmallVector<T> convert_vector(const Vec& vec) {
+  return SmallVector<T>(vec.begin(), vec.end());
+}
+
+// Return an array that can be used as a map key for |vec|, which must have
+// size <= MAX_NDIM.
+//
+// There are 2 differences from the const_param util in kernel_utils.cuh:
+// 1. The rest of the array is filled with 0.
+// 2. This util can be used in .cpp files.
+template <typename T, template <typename U> class Vec>
+inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
+  if (vec.size() > MAX_NDIM) {
+    throw std::runtime_error(
+        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
+  }
+  std::array<T, MAX_NDIM> result = {};
+  std::copy_n(vec.begin(), vec.size(), result.begin());
+  return result;
+}
+
+// Helpers used by get_data_ptrs to get pointers.
+inline void* get_data_ptr(const array& arr) {
+  return const_cast<void*>(arr.data<void>());
+}
+
+template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
+inline void* get_data_ptr(T& scalar) {
+  return &scalar;
+}
+
+// Return an array filled with the data pointers of args.
+template <typename... Args>
+inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
+  return {get_data_ptr(args)...};
+}
+
+// Map dtype to cudnn data type.
+inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
+  switch (dtype) {
+    case int8:
+      return CUDNN_DATA_INT8;
+    case int32:
+      return CUDNN_DATA_INT32;
+    case uint8:
+      return CUDNN_DATA_UINT8;
+    case float16:
+      return CUDNN_DATA_HALF;
+    case bfloat16:
+      return CUDNN_DATA_BFLOAT16;
+    case float32:
+      return CUDNN_DATA_FLOAT;
+    case float64:
+      return CUDNN_DATA_DOUBLE;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
+  }
+}
+
+// Create a tensor descriptor from |x|.
+cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
+
+// Create a tensor descriptor from |x|, transposed from NHWC to NCHW.
+cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
+
+// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
+// from NHWC to NCHW.
+cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
+
+// Create a 4D scalar tensor descriptor, which is passed by value.
+cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
+
+// Find a working plan for |op_graph|.
+std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
+    cudnnHandle_t handle,
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph);
+
+// Encode the plan to the command buffer by capturing.
+bool encode_cudnn_plan_with_capturing(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs);
+
+#if CUDNN_VERSION >= 90500
+// Encode the plan to the command buffer using cuDNN's native CUDA graph API.
+// If |graph| is empty it will be populated, otherwise it will be updated.
+bool encode_cudnn_plan_with_graph_api(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    int num_args,
+    const int64_t* uids,
+    void** data_ptrs);
+#endif
+
+// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
+template <typename... Args>
+bool encode_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    std::initializer_list<int64_t> uids,
+    Args&... args) {
+  assert(uids.size() == sizeof...(args));
+  auto data_ptrs = get_data_ptrs(args...);
+  return encode_cudnn_plan_with_capturing(
+      encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
+}
+
+#if CUDNN_VERSION >= 90500
+template <typename... Args>
+bool encode_cudnn_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    CudaGraph& graph,
+    std::initializer_list<int64_t> uids,
+    Args&... args) {
+  assert(uids.size() == sizeof...(args));
+  auto data_ptrs = get_data_ptrs(args...);
+  return encode_cudnn_plan_with_graph_api(
+      encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
+}
+#endif
+
+} // namespace mlx::core
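The variadic encode_cudnn_plan pairs each uid with one argument positionally: arrays contribute device pointers through get_data_ptr(const array&), while a scalar (for example an epsilon fed to a by-value tensor from build_cudnn_scalar_4d) contributes its own address. A standalone sketch of just that pairing mechanism, with stand-in helpers (get_ptr and ptrs are illustrative names, not the header's):

#include <array>
#include <cstdint>
#include <cstdio>

template <typename T>
void* get_ptr(T& v) {
  return &v;  // stands in for get_data_ptr
}

template <typename... Args>
std::array<void*, sizeof...(Args)> ptrs(Args&... args) {
  return {get_ptr(args)...};  // stands in for get_data_ptrs
}

int main() {
  float x_buf[4] = {}, s_buf[4] = {}, y_buf[4] = {};
  float eps = 1e-5f;
  std::array<int64_t, 4> uids = {'x', 's', 'y', 'e'};
  auto data = ptrs(x_buf, s_buf, y_buf, eps);  // one pointer per uid, in order
  std::printf("uid '%c' -> %p (address of eps)\n", (char)uids[3], data[3]);
}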
mlx/backend/cuda/rms_norm.cpp (new file, 145 lines)
@@ -0,0 +1,145 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cudnn_utils.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/lru_cache.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/fast_primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core {
+
+namespace {
+
+struct RMSNormCacheKey {
+  int device_id;
+  cudnnDataType_t cudnn_dtype;
+  std::array<int, MAX_NDIM> x_shape;
+  std::array<int64_t, MAX_NDIM> x_strides;
+  std::array<int, MAX_NDIM> scale_shape;
+  std::array<int64_t, MAX_NDIM> scale_strides;
+  uint8_t x_alignment;
+  uint8_t scale_alignment;
+};
+
+auto& rms_norm_cache() {
+  static LRUBytesKeyCache<RMSNormCacheKey, cudnn_frontend::ExecutionPlan> cache(
+      /* capacity */ 32);
+  return cache;
+}
+
+auto build_norm_op_graph(
+    cudnnHandle_t handle,
+    cudnnBackendDescriptorType_t backend_type,
+    cudnnBackendNormMode_t norm_mode,
+    cudnnBackendNormFwdPhase_t norm_phase,
+    const array& x,
+    const array& scale,
+    array& y) {
+  auto op = cudnn_frontend::OperationBuilder(backend_type)
+                .setNormalizationMode(norm_mode)
+                .setNormFwdPhase(norm_phase)
+                .setxDesc(build_cudnn_tensor_4d_nchw('x', x))
+                .setScale(build_cudnn_tensor_4d_nchw('s', scale))
+                .setyDesc(build_cudnn_tensor_4d_nchw('y', y))
+                .setEpsilonTensor(build_cudnn_scalar_4d('e', float32))
+                .build();
+
+  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
+  return cudnn_frontend::OperationGraphBuilder()
+      .setHandle(handle)
+      .setOperationGraph(ops.size(), ops.data())
+      .build();
+}
+
+} // namespace
+
+namespace fast {
+
+bool RMSNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
+void RMSNorm::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("RMSNorm::eval_gpu");
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  assert(inputs.size() == 2);
+  const array& x = inputs[0];
+  array scale = inputs[1];
+
+  // cuDNN does not accept a scalar as the scale.
+  if (scale.ndim() == 0) {
+    array scale_copy({1, 1, 1, x.shape(-1)}, scale.dtype(), nullptr, {});
+    fill_gpu(scale, scale_copy, s);
+    encoder.add_temporary(scale_copy);
+    scale = scale_copy;
+  }
+
+  // TODO: Handle donations.
+  assert(outputs.size() == 1);
+  array& out = outputs[0];
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  encoder.set_input_array(x);
+  encoder.set_input_array(scale);
+  encoder.set_output_array(out);
+
+  // Search cache.
+  RMSNormCacheKey cache_key{
+      encoder.device().cuda_device(),
+      dtype_to_cudnn_type(out.dtype()),
+      vector_key(x.shape()),
+      vector_key(x.strides()),
+      vector_key(scale.shape()),
+      vector_key(scale.strides()),
+      get_alignment(x),
+      get_alignment(scale)};
+  if (auto it = rms_norm_cache().find(cache_key);
+      it != rms_norm_cache().end()) {
+    auto& plan = it->second;
+    if (!encode_cudnn_plan(
+            encoder, plan, {'x', 's', 'y', 'e'}, x, scale, out, eps_)) {
+      throw std::runtime_error("[norm] Cached plan failed to execute.");
+    }
+    return;
+  }
+
+  // Try to build op graph.
+  auto handle = encoder.device().cudnn_handle();
+  auto backend_type = CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR;
+  auto norm_mode = CUDNN_RMS_NORM;
+  auto norm_phase = CUDNN_NORM_FWD_INFERENCE;
+  auto op_graph = build_norm_op_graph(
+      handle, backend_type, norm_mode, norm_phase, x, scale, out);
+
+  // Find a plan for the graph and execute it.
+  auto plan = find_cudnn_plan_from_op_graph(
+      handle, backend_type, out.dtype(), op_graph);
+  if (!plan) {
+    throw std::runtime_error("[norm] Unable to find an execution plan.");
+  }
+  if (!encode_cudnn_plan(
+          encoder, *plan, {'x', 's', 'y', 'e'}, x, scale, out, eps_)) {
+    throw std::runtime_error("[norm] Failed to run execution plan.");
+  }
+  rms_norm_cache().emplace(cache_key, std::move(*plan));
+}
+
+void RMSNormVJP::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("RMSNormVJP::eval_gpu");
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  throw std::runtime_error("NYI");
+}
+
+} // namespace fast
+
+} // namespace mlx::core
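For reference, the forward computation the CUDNN_RMS_NORM inference graph encodes is, per row over the last axis, y = x / sqrt(mean(x^2) + eps) * scale, with eps supplied through the by-value 'e' tensor. A minimal CPU sketch of those semantics (rms_norm_ref is a hypothetical reference implementation, not part of the diff):

#include <cmath>
#include <cstdio>
#include <vector>

// Normalize each row by the root-mean-square of its elements, then multiply
// by a learned per-element scale.
void rms_norm_ref(
    const std::vector<float>& x,
    const std::vector<float>& scale,
    std::vector<float>& y,
    int rows,
    int dim,
    float eps) {
  for (int r = 0; r < rows; ++r) {
    float ss = 0.0f;
    for (int i = 0; i < dim; ++i) {
      float v = x[r * dim + i];
      ss += v * v;
    }
    float inv_rms = 1.0f / std::sqrt(ss / dim + eps);
    for (int i = 0; i < dim; ++i) {
      y[r * dim + i] = x[r * dim + i] * inv_rms * scale[i];
    }
  }
}

int main() {
  std::vector<float> x = {1, 2, 3, 4}, scale = {1, 1}, y(4);
  rms_norm_ref(x, scale, y, /*rows=*/2, /*dim=*/2, 1e-5f);
  std::printf("%f %f\n", y[0], y[1]);  // ~0.6325, ~1.2649
}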