mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-03 01:06:43 +08:00
[CUDA] Initial implementation of Convolution with cuDNN (#2385)
* Link with cuDNN * Initial implementation * Remove backend apis * Fix recording cudnn conv * More unused backend apis * Fix C++ conv tests * include cudnn as python dep * Install libcudnn9-dev-cuda-12 in CI * cudnn only accepts contiguous inputs * Switch to backend apis * Plan needs to be kept alive * Turn off tf32 * Add cache * Test the native cuda graph api * Set cudnn stream before execution * Make LRUCache more like a normal container * Do error check for cublas handle * Zero-initilizing array * Use tf32 for conv * Skip TestConv.test_torch_conv_2D test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
70dc336785
commit
6f5874a2f2
@ -216,6 +216,7 @@ jobs:
|
|||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
|
sudo apt-get install libcudnn9-dev-cuda-12
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
python3 -m venv env
|
python3 -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
@ -385,7 +386,7 @@ jobs:
|
|||||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt install cuda-toolkit-12-9
|
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
sudo apt-get install zip
|
sudo apt-get install zip
|
||||||
pip install auditwheel
|
pip install auditwheel
|
||||||
|
@ -15,6 +15,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
@ -131,6 +132,23 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
|||||||
# Use NVRTC and driver APIs.
|
# Use NVRTC and driver APIs.
|
||||||
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||||
|
|
||||||
|
# Use the frontend APIs of cuDNN.
|
||||||
|
FetchContent_Declare(
|
||||||
|
cudnn
|
||||||
|
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||||
|
GIT_TAG v1.12.1
|
||||||
|
GIT_SHALLOW TRUE
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
|
||||||
|
FetchContent_MakeAvailable(cudnn)
|
||||||
|
target_link_libraries(mlx PRIVATE cudnn_frontend)
|
||||||
|
# Link with the actual cuDNN libraries.
|
||||||
|
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
|
||||||
|
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||||
|
|
||||||
# Suppress nvcc warnings on MLX headers.
|
# Suppress nvcc warnings on MLX headers.
|
||||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||||
--diag_suppress=997>)
|
--diag_suppress=997>)
|
||||||
|
340
mlx/backend/cuda/conv.cpp
Normal file
340
mlx/backend/cuda/conv.cpp
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
|
#include "mlx/backend/cuda/lru_cache.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
// cudnn_frontend.h redefines this macro.
|
||||||
|
#undef CHECK_CUDA_ERROR
|
||||||
|
|
||||||
|
#include <cudnn_frontend.h>
|
||||||
|
#include <cudnn_frontend_find_plan.h>
|
||||||
|
#include <fmt/format.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Not all engines support it so can not use this API now.
|
||||||
|
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||||
|
|
||||||
|
struct ConvCacheKey {
|
||||||
|
int device_id;
|
||||||
|
cudnnBackendDescriptorType_t backend_type;
|
||||||
|
cudnnDataType_t cudnn_type;
|
||||||
|
std::array<int, MAX_NDIM> input_shape;
|
||||||
|
std::array<int, MAX_NDIM> filter_shape;
|
||||||
|
std::array<int, MAX_NDIM> padding_lo;
|
||||||
|
std::array<int, MAX_NDIM> padding_hi;
|
||||||
|
std::array<int, MAX_NDIM> stride;
|
||||||
|
std::array<int, MAX_NDIM> dilation;
|
||||||
|
int groups;
|
||||||
|
uint8_t input_alignment;
|
||||||
|
uint8_t filter_alignment;
|
||||||
|
uint8_t output_alignment;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto& conv_cache() {
|
||||||
|
static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
|
||||||
|
/* capacity */ 128);
|
||||||
|
return cache;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
|
||||||
|
return std::vector<T>(vec.begin(), vec.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
|
||||||
|
if (vec.size() > MAX_NDIM) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||||
|
}
|
||||||
|
std::array<T, MAX_NDIM> result = {};
|
||||||
|
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto nhwc_to_nchw(const array& x) {
|
||||||
|
auto shape = convert_vector<int64_t>(x.shape());
|
||||||
|
shape.insert(shape.begin() + 1, shape.back());
|
||||||
|
shape.erase(shape.end() - 1);
|
||||||
|
auto strides = convert_vector<int64_t>(x.strides());
|
||||||
|
strides.insert(strides.begin() + 1, strides.back());
|
||||||
|
strides.erase(strides.end() - 1);
|
||||||
|
return std::make_tuple(shape, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||||
|
switch (dtype) {
|
||||||
|
case int8:
|
||||||
|
return CUDNN_DATA_INT8;
|
||||||
|
case int32:
|
||||||
|
return CUDNN_DATA_INT32;
|
||||||
|
case uint8:
|
||||||
|
return CUDNN_DATA_UINT8;
|
||||||
|
case float16:
|
||||||
|
return CUDNN_DATA_HALF;
|
||||||
|
case bfloat16:
|
||||||
|
return CUDNN_DATA_BFLOAT16;
|
||||||
|
case float32:
|
||||||
|
return CUDNN_DATA_FLOAT;
|
||||||
|
case float64:
|
||||||
|
return CUDNN_DATA_DOUBLE;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint8_t get_alignment(const array& x) {
|
||||||
|
uint8_t alignment = 1;
|
||||||
|
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||||
|
for (; alignment < 32; alignment *= 2) {
|
||||||
|
if (address % (alignment * 2)) {
|
||||||
|
return alignment;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return alignment;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
|
||||||
|
auto [shape, strides] = nhwc_to_nchw(x);
|
||||||
|
return cudnn_frontend::TensorBuilder()
|
||||||
|
.setDim(shape.size(), shape.data())
|
||||||
|
.setStrides(strides.size(), strides.data())
|
||||||
|
.setId(id)
|
||||||
|
.setAlignment(get_alignment(x))
|
||||||
|
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnn_frontend::EngineConfigList get_engine_configs(
|
||||||
|
cudnnBackendDescriptorType_t backend_type,
|
||||||
|
Dtype dtype,
|
||||||
|
cudnn_frontend::OperationGraph& op_graph,
|
||||||
|
bool use_fallback = false) {
|
||||||
|
cudnn_frontend::GeneratorSource source;
|
||||||
|
if (use_fallback) {
|
||||||
|
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
|
||||||
|
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||||
|
.setOperationGraph(op_graph)
|
||||||
|
.setOperation(backend_type)
|
||||||
|
.build();
|
||||||
|
return fallback.getFallbackList();
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
source = [](cudnn_frontend::OperationGraph& op_graph) {
|
||||||
|
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||||
|
.setOperationGraph(op_graph)
|
||||||
|
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||||
|
.build();
|
||||||
|
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnn_frontend::EngineConfigGenerator generator(1, &source);
|
||||||
|
auto configs = generator.generate_engine_config(op_graph);
|
||||||
|
|
||||||
|
cudnn_frontend::EngineConfigList filtered_configs;
|
||||||
|
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||||
|
if (cudnn_frontend::hasNumericalNote<
|
||||||
|
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||||
|
dtype == float32 && !env::enable_tf32()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
return filtered_configs;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool execute_plan(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::ExecutionPlan& plan,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out) {
|
||||||
|
int workspace_size = plan.getWorkspaceSize();
|
||||||
|
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
||||||
|
|
||||||
|
int64_t uids[3] = {'x', 'w', 'y'};
|
||||||
|
void* data_ptrs[3] = {
|
||||||
|
const_cast<void*>(in.data<void>()),
|
||||||
|
const_cast<void*>(wt.data<void>()),
|
||||||
|
out.data<void>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
||||||
|
.setWorkspacePointer(workspace.data<void>())
|
||||||
|
.setDataPointers(3, data_ptrs)
|
||||||
|
.setUids(3, uids)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
cudnnSetStream(handle, encoder.stream());
|
||||||
|
|
||||||
|
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||||
|
cudaGraph_t graph;
|
||||||
|
cudaGraphCreate(&graph, 0);
|
||||||
|
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||||
|
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||||
|
if (cudnnBackendPopulateCudaGraph(
|
||||||
|
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
|
||||||
|
CUDNN_STATUS_SUCCESS) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
encoder.add_graph_node(graph);
|
||||||
|
#else
|
||||||
|
auto capture = encoder.capture_context();
|
||||||
|
if (cudnnBackendExecute(
|
||||||
|
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
|
||||||
|
CUDNN_STATUS_SUCCESS) {
|
||||||
|
// Discard the captured graph when failed.
|
||||||
|
capture.discard = true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
encoder.add_temporary(workspace);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool try_engines(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
cudnn_frontend::EngineConfigList& configs,
|
||||||
|
const ConvCacheKey& cache_key,
|
||||||
|
const std::string& op_graph_tag,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out) {
|
||||||
|
for (auto& config : configs) {
|
||||||
|
try {
|
||||||
|
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||||
|
.setHandle(encoder.device().cudnn_handle())
|
||||||
|
.setEngineConfig(config, op_graph_tag)
|
||||||
|
.build();
|
||||||
|
if (execute_plan(encoder, plan, in, wt, out)) {
|
||||||
|
conv_cache().emplace(cache_key, std::move(plan));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} catch (cudnn_frontend::cudnnException&) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
array in = inputs[0];
|
||||||
|
array wt = inputs[1];
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
// cuDNN requires contiguous input.
|
||||||
|
// TODO: Handle NCHW format specially.
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
in = contiguous_copy_gpu(in, s);
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
|
if (!wt.flags().row_contiguous) {
|
||||||
|
wt = contiguous_copy_gpu(wt, s);
|
||||||
|
encoder.add_temporary(wt);
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_input_array(wt);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
|
||||||
|
auto cudnn_type = dtype_to_cudnn_type(in.dtype());
|
||||||
|
|
||||||
|
// Search cache.
|
||||||
|
ConvCacheKey cache_key{
|
||||||
|
encoder.device().cuda_device(),
|
||||||
|
backend_type,
|
||||||
|
cudnn_type,
|
||||||
|
fixed_vector(in.shape()),
|
||||||
|
fixed_vector(wt.shape()),
|
||||||
|
fixed_vector(padding_lo_),
|
||||||
|
fixed_vector(padding_hi_),
|
||||||
|
fixed_vector(kernel_strides_),
|
||||||
|
fixed_vector(kernel_dilation_),
|
||||||
|
groups_,
|
||||||
|
get_alignment(in),
|
||||||
|
get_alignment(wt),
|
||||||
|
get_alignment(out)};
|
||||||
|
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||||
|
if (!execute_plan(encoder, it->second, in, wt, out)) {
|
||||||
|
throw std::runtime_error("Cached convolution plan failed to execute.");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build operation graph.
|
||||||
|
auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
|
||||||
|
? CUDNN_DATA_FLOAT
|
||||||
|
: cudnn_type;
|
||||||
|
|
||||||
|
auto stride = convert_vector<int64_t>(kernel_strides_);
|
||||||
|
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
||||||
|
auto padding_hi = convert_vector<int64_t>(padding_hi_);
|
||||||
|
auto dilation = convert_vector<int64_t>(kernel_dilation_);
|
||||||
|
|
||||||
|
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
||||||
|
.setDataType(compute_data_type)
|
||||||
|
.setMathMode(CUDNN_CROSS_CORRELATION)
|
||||||
|
.setNDims(stride.size())
|
||||||
|
.setStrides(stride.size(), stride.data())
|
||||||
|
.setPrePadding(padding_lo.size(), padding_lo.data())
|
||||||
|
.setPostPadding(padding_hi.size(), padding_hi.data())
|
||||||
|
.setDilation(dilation.size(), dilation.data())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||||
|
.setxDesc(build_tensor('x', in))
|
||||||
|
.setwDesc(build_tensor('w', wt))
|
||||||
|
.setyDesc(build_tensor('y', out))
|
||||||
|
.setcDesc(conv_desc)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
||||||
|
auto op_graph = cudnn_frontend::OperationGraphBuilder()
|
||||||
|
.setHandle(encoder.device().cudnn_handle())
|
||||||
|
.setOperationGraph(ops.size(), ops.data())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Try to run plans based on heuristics.
|
||||||
|
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||||
|
auto op_graph_tag = op_graph.getTag();
|
||||||
|
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Then try fallback plans.
|
||||||
|
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||||
|
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
throw std::runtime_error("Unable to find an engine for convolution.");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -9,12 +9,23 @@
|
|||||||
#include <future>
|
#include <future>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||||
// This should be less than 255
|
// This should be less than 255
|
||||||
constexpr int default_max_nodes_per_graph = 20;
|
constexpr int default_max_nodes_per_graph = 20;
|
||||||
|
|
||||||
|
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||||
|
if (err != CUDNN_STATUS_SUCCESS) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int cuda_graph_cache_size() {
|
int cuda_graph_cache_size() {
|
||||||
static int cache_size = []() {
|
static int cache_size = []() {
|
||||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
||||||
@ -22,7 +33,7 @@ int cuda_graph_cache_size() {
|
|||||||
return cache_size;
|
return cache_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace cu {
|
} // namespace
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||||
@ -40,11 +51,14 @@ Device::Device(int device) : device_(device) {
|
|||||||
}
|
}
|
||||||
// The cublasLt handle is used by matmul.
|
// The cublasLt handle is used by matmul.
|
||||||
make_current();
|
make_current();
|
||||||
cublasLtCreate(<_);
|
CHECK_CUBLAS_ERROR(cublasLtCreate(<_));
|
||||||
|
// The cudnn handle is used by Convolution.
|
||||||
|
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
|
||||||
}
|
}
|
||||||
|
|
||||||
Device::~Device() {
|
Device::~Device() {
|
||||||
cublasLtDestroy(lt_);
|
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::make_current() {
|
void Device::make_current() {
|
||||||
@ -66,29 +80,36 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
|
enc.device().make_current();
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
|
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
|
||||||
|
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||||
|
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
|
||||||
|
if (discard) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and add as single kernel node when possible.
|
||||||
size_t num_nodes;
|
size_t num_nodes;
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
|
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
|
||||||
if (num_nodes == 1) {
|
if (num_nodes == 1) {
|
||||||
cudaGraphNode_t captured_node;
|
cudaGraphNode_t captured_node;
|
||||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
|
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
|
||||||
CUDA_KERNEL_NODE_PARAMS params;
|
cudaGraphNodeType type;
|
||||||
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
|
CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
|
||||||
cudaGraphNode_t node;
|
if (type == cudaGraphNodeTypeKernel) {
|
||||||
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms));
|
CUDA_KERNEL_NODE_PARAMS params;
|
||||||
enc.insert_graph_dependencies(GraphNode{node, 'K'});
|
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
|
||||||
} else {
|
enc.add_kernel_node(params);
|
||||||
cudaGraphNode_t node;
|
return;
|
||||||
CHECK_CUDA_ERROR(
|
}
|
||||||
cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
|
|
||||||
enc.insert_graph_dependencies(GraphNode{node, 'G'});
|
|
||||||
}
|
}
|
||||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
|
// Otherwise add the captured graph as subgraph.
|
||||||
|
enc.add_graph_node(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
||||||
@ -221,10 +242,7 @@ void CommandEncoder::add_kernel_node(
|
|||||||
kernel_params.gridDim = grid_dim;
|
kernel_params.gridDim = grid_dim;
|
||||||
kernel_params.blockDim = block_dim;
|
kernel_params.blockDim = block_dim;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
cudaGraphNode_t node;
|
add_kernel_node(kernel_params);
|
||||||
CHECK_CUDA_ERROR(
|
|
||||||
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
|
||||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_kernel_node(
|
void CommandEncoder::add_kernel_node(
|
||||||
@ -241,12 +259,27 @@ void CommandEncoder::add_kernel_node(
|
|||||||
kernel_params.blockDimY = block_dim.y;
|
kernel_params.blockDimY = block_dim.y;
|
||||||
kernel_params.blockDimZ = block_dim.z;
|
kernel_params.blockDimZ = block_dim.z;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
CUgraphNode node;
|
add_kernel_node(kernel_params);
|
||||||
CHECK_CUDA_ERROR(
|
}
|
||||||
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
|
||||||
|
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
|
||||||
|
cudaGraphNode_t node;
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||||
|
CUgraphNode node;
|
||||||
|
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||||
|
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||||
|
cudaGraphNode_t node;
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||||
|
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||||
|
}
|
||||||
|
|
||||||
void CommandEncoder::commit() {
|
void CommandEncoder::commit() {
|
||||||
if (!temporaries_.empty()) {
|
if (!temporaries_.empty()) {
|
||||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||||
@ -331,6 +364,4 @@ CommandEncoder& get_command_encoder(Stream s) {
|
|||||||
return device(s.device).get_command_encoder(s);
|
return device(s.device).get_command_encoder(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace mlx::core::cu
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include <cublasLt.h>
|
#include <cublasLt.h>
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
|
#include <cudnn.h>
|
||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@ -21,6 +22,7 @@ class CommandEncoder {
|
|||||||
~CaptureContext();
|
~CaptureContext();
|
||||||
cudaGraph_t graph;
|
cudaGraph_t graph;
|
||||||
CommandEncoder& enc;
|
CommandEncoder& enc;
|
||||||
|
bool discard{false};
|
||||||
};
|
};
|
||||||
struct ConcurrentContext {
|
struct ConcurrentContext {
|
||||||
ConcurrentContext(CommandEncoder& enc);
|
ConcurrentContext(CommandEncoder& enc);
|
||||||
@ -65,6 +67,11 @@ class CommandEncoder {
|
|||||||
void
|
void
|
||||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
||||||
|
|
||||||
|
// Low-level graph helpers.
|
||||||
|
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||||
|
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||||
|
void add_graph_node(cudaGraph_t child);
|
||||||
|
|
||||||
void add_temporary(const array& arr) {
|
void add_temporary(const array& arr) {
|
||||||
temporaries_.push_back(arr.data_shared_ptr());
|
temporaries_.push_back(arr.data_shared_ptr());
|
||||||
}
|
}
|
||||||
@ -73,6 +80,10 @@ class CommandEncoder {
|
|||||||
void maybe_commit();
|
void maybe_commit();
|
||||||
void commit();
|
void commit();
|
||||||
|
|
||||||
|
Device& device() {
|
||||||
|
return device_;
|
||||||
|
}
|
||||||
|
|
||||||
CudaStream& stream() {
|
CudaStream& stream() {
|
||||||
return stream_;
|
return stream_;
|
||||||
}
|
}
|
||||||
@ -137,12 +148,16 @@ class Device {
|
|||||||
cublasLtHandle_t lt_handle() const {
|
cublasLtHandle_t lt_handle() const {
|
||||||
return lt_;
|
return lt_;
|
||||||
}
|
}
|
||||||
|
cudnnHandle_t cudnn_handle() const {
|
||||||
|
return cudnn_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int device_;
|
int device_;
|
||||||
int compute_capability_major_;
|
int compute_capability_major_;
|
||||||
int compute_capability_minor_;
|
int compute_capability_minor_;
|
||||||
cublasLtHandle_t lt_;
|
cublasLtHandle_t lt_;
|
||||||
|
cudnnHandle_t cudnn_;
|
||||||
std::unordered_map<int, CommandEncoder> encoders_;
|
std::unordered_map<int, CommandEncoder> encoders_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
146
mlx/backend/cuda/lru_cache.h
Normal file
146
mlx/backend/cuda/lru_cache.h
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename K,
|
||||||
|
typename V,
|
||||||
|
template <typename...> typename M = std::unordered_map>
|
||||||
|
class LRUCache {
|
||||||
|
public:
|
||||||
|
using value_type = std::pair<K, V>;
|
||||||
|
using list_type = std::list<value_type>;
|
||||||
|
using iterator = typename list_type::iterator;
|
||||||
|
using const_iterator = typename list_type::const_iterator;
|
||||||
|
using map_type = M<K, iterator>;
|
||||||
|
|
||||||
|
explicit LRUCache(size_t capacity) : capacity_(capacity) {}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return map_.size();
|
||||||
|
}
|
||||||
|
size_t capacity() const {
|
||||||
|
return capacity_;
|
||||||
|
}
|
||||||
|
bool empty() const {
|
||||||
|
return vlist_.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
void resize(size_t new_capacity) {
|
||||||
|
capacity_ = new_capacity;
|
||||||
|
trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator begin() {
|
||||||
|
return vlist_.begin();
|
||||||
|
}
|
||||||
|
const_iterator begin() const {
|
||||||
|
return vlist_.begin();
|
||||||
|
}
|
||||||
|
iterator end() {
|
||||||
|
return vlist_.end();
|
||||||
|
}
|
||||||
|
const_iterator end() const {
|
||||||
|
return vlist_.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
map_.clear();
|
||||||
|
vlist_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator find(const K& key) {
|
||||||
|
auto it = map_.find(key);
|
||||||
|
if (it == map_.end())
|
||||||
|
return end();
|
||||||
|
vlist_.splice(vlist_.begin(), vlist_, it->second);
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
std::pair<iterator, bool> emplace(const K& key, U&& value) {
|
||||||
|
auto it = map_.find(key);
|
||||||
|
if (it != map_.end()) {
|
||||||
|
vlist_.splice(vlist_.begin(), vlist_, it->second);
|
||||||
|
return {it->second, false};
|
||||||
|
}
|
||||||
|
|
||||||
|
vlist_.emplace_front(key, std::forward<U>(value));
|
||||||
|
map_[key] = vlist_.begin();
|
||||||
|
|
||||||
|
trim();
|
||||||
|
|
||||||
|
return {vlist_.begin(), true};
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator erase(iterator pos) {
|
||||||
|
map_.erase(pos->first);
|
||||||
|
return vlist_.erase(pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void trim() {
|
||||||
|
while (map_.size() > capacity_) {
|
||||||
|
auto last = std::prev(vlist_.end());
|
||||||
|
map_.erase(last->first);
|
||||||
|
vlist_.pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
list_type vlist_;
|
||||||
|
map_type map_;
|
||||||
|
size_t capacity_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Turn a POD struct into a container key by doing bytes compare.
|
||||||
|
template <typename T>
|
||||||
|
struct BytesKey {
|
||||||
|
T pod;
|
||||||
|
static_assert(std::is_standard_layout_v<T>, "T is not POD");
|
||||||
|
|
||||||
|
BytesKey(T pod) : pod(std::move(pod)) {}
|
||||||
|
|
||||||
|
BytesKey(const BytesKey& other) {
|
||||||
|
memcpy(&pod, &other.pod, sizeof(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
BytesKey(BytesKey&& other) {
|
||||||
|
memcpy(&pod, &other.pod, sizeof(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(const BytesKey& other) const {
|
||||||
|
auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
|
||||||
|
auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
|
||||||
|
return memcmp(ptr1, ptr2, sizeof(T)) == 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Compute hash according to the bytes value of T.
|
||||||
|
template <typename T>
|
||||||
|
struct BytesHash {
|
||||||
|
static_assert(std::is_standard_layout_v<T>, "T is not POD");
|
||||||
|
|
||||||
|
size_t operator()(const T& pod) const {
|
||||||
|
auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
|
||||||
|
uint32_t value = 0x811C9DC5;
|
||||||
|
for (int i = 0; i < sizeof(T); ++i) {
|
||||||
|
value ^= ptr[i];
|
||||||
|
value *= 0x01000193;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename K, typename V>
|
||||||
|
using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
|
||||||
|
|
||||||
|
template <typename K, typename V>
|
||||||
|
using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -8,7 +8,6 @@
|
|||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cublasLt.h>
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
@ -18,16 +17,6 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
|
||||||
|
|
||||||
void check_cublas_error(const char* name, cublasStatus_t err) {
|
|
||||||
if (err != CUBLAS_STATUS_SUCCESS) {
|
|
||||||
// TODO: Use cublasGetStatusString when it is widely available.
|
|
||||||
throw std::runtime_error(
|
|
||||||
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CublasPreference {
|
struct CublasPreference {
|
||||||
CublasPreference(Device& device) {
|
CublasPreference(Device& device) {
|
||||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||||
|
@ -71,7 +71,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
|||||||
}
|
}
|
||||||
|
|
||||||
NO_GPU(BlockMaskedMM)
|
NO_GPU(BlockMaskedMM)
|
||||||
NO_GPU(Convolution)
|
|
||||||
NO_GPU(DynamicSlice)
|
NO_GPU(DynamicSlice)
|
||||||
NO_GPU(DynamicSliceUpdate)
|
NO_GPU(DynamicSliceUpdate)
|
||||||
NO_GPU(FFT)
|
NO_GPU(FFT)
|
||||||
|
@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
|
|||||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||||
|
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||||
|
// TODO: Use cublasGetStatusString when it is widely available.
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void check_cuda_error(const char* name, cudaError_t err) {
|
void check_cuda_error(const char* name, cudaError_t err) {
|
||||||
if (err != cudaSuccess) {
|
if (err != cudaSuccess) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <cublasLt.h>
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
@ -33,10 +34,12 @@ class CudaStream {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Throw exception if the cuda API does not succeed.
|
// Throw exception if the cuda API does not succeed.
|
||||||
|
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||||
void check_cuda_error(const char* name, cudaError_t err);
|
void check_cuda_error(const char* name, cudaError_t err);
|
||||||
void check_cuda_error(const char* name, CUresult err);
|
void check_cuda_error(const char* name, CUresult err);
|
||||||
|
|
||||||
// The macro version that prints the command that failed.
|
// The macro version that prints the command that failed.
|
||||||
|
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||||
|
|
||||||
// Convert Dtype to CUDA C++ types.
|
// Convert Dtype to CUDA C++ types.
|
||||||
|
@ -16,7 +16,7 @@ rm "${repaired_wheel}"
|
|||||||
mlx_so="mlx/lib/libmlx.so"
|
mlx_so="mlx/lib/libmlx.so"
|
||||||
rpath=$(patchelf --print-rpath "${mlx_so}")
|
rpath=$(patchelf --print-rpath "${mlx_so}")
|
||||||
base="\$ORIGIN/../../nvidia"
|
base="\$ORIGIN/../../nvidia"
|
||||||
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib
|
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
|
||||||
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
||||||
python ../python/scripts/repair_record.py ${mlx_so}
|
python ../python/scripts/repair_record.py ${mlx_so}
|
||||||
|
|
||||||
|
@ -15,19 +15,12 @@ cuda_skip = {
|
|||||||
"TestOps.test_hadamard_grad_vmap",
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
# Convolutions NYI
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_asymmetric_padding",
|
|
||||||
"TestConv.test_basic_grad_shapes",
|
|
||||||
"TestConv.test_conv2d_unaligned_channels",
|
|
||||||
"TestConv.test_conv_1d_groups_flipped",
|
"TestConv.test_conv_1d_groups_flipped",
|
||||||
"TestConv.test_conv_general_flip_grad",
|
"TestConv.test_conv_general_flip_grad",
|
||||||
"TestConv.test_conv_groups_grad",
|
"TestConv.test_conv_groups_grad",
|
||||||
"TestConv.test_numpy_conv",
|
|
||||||
"TestConv.test_repeated_conv",
|
|
||||||
"TestConv.test_torch_conv_1D",
|
|
||||||
"TestConv.test_torch_conv_1D_grad",
|
"TestConv.test_torch_conv_1D_grad",
|
||||||
"TestConv.test_torch_conv_2D",
|
"TestConv.test_torch_conv_2D",
|
||||||
"TestConv.test_torch_conv_2D_grad",
|
"TestConv.test_torch_conv_2D_grad",
|
||||||
"TestConv.test_torch_conv_3D",
|
|
||||||
"TestConv.test_torch_conv_3D_grad",
|
"TestConv.test_torch_conv_3D_grad",
|
||||||
"TestConv.test_torch_conv_depthwise",
|
"TestConv.test_torch_conv_depthwise",
|
||||||
"TestConv.test_torch_conv_general",
|
"TestConv.test_torch_conv_general",
|
||||||
@ -40,10 +33,6 @@ cuda_skip = {
|
|||||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||||
"TestExportImport.test_export_conv",
|
|
||||||
"TestLayers.test_conv1d",
|
|
||||||
"TestLayers.test_conv2d",
|
|
||||||
"TestVmap.test_vmap_conv",
|
|
||||||
# FFTs NYI
|
# FFTs NYI
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
|
1
setup.py
1
setup.py
@ -289,6 +289,7 @@ if __name__ == "__main__":
|
|||||||
install_requires += [
|
install_requires += [
|
||||||
"nvidia-cublas-cu12==12.9.*",
|
"nvidia-cublas-cu12==12.9.*",
|
||||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||||
|
"nvidia-cudnn-cu12==12.9.*",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
name = "mlx-cpu"
|
name = "mlx-cpu"
|
||||||
|
Loading…
Reference in New Issue
Block a user