// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"

#include <cublasLt.h>
#include <cuda.h>
#include <cudnn.h>
#include <thrust/execution_policy.h>

#include <functional>
#include <unordered_map>

namespace mlx::core::cu {

class CommandEncoder {
 public:
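  // RAII context: work encoded while it is alive is captured into `graph`
  // rather than recorded on the main graph; setting `discard` drops the
  // captured work in the destructor.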
  struct CaptureContext {
    CaptureContext(CommandEncoder& enc);
    ~CaptureContext();
    CudaGraph graph;
    CommandEncoder& enc;
    bool discard{false};
  };
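  // RAII context: nodes added while it is alive are treated as concurrent,
  // i.e. no dependencies are inserted between them, only to nodes outside
  // the context.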
  struct ConcurrentContext {
    ConcurrentContext(CommandEncoder& enc);
    ~ConcurrentContext();
    CommandEncoder& enc;
  };

  explicit CommandEncoder(Device& d);

  CommandEncoder(const CommandEncoder&) = delete;
  CommandEncoder& operator=(const CommandEncoder&) = delete;

  CaptureContext capture_context() {
    return CaptureContext{*this};
  }
  ConcurrentContext concurrent_context() {
    return ConcurrentContext{*this};
  }

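  // Register the arrays an encoded node reads and writes; dependencies
  // between graph nodes are derived from this buffer usage.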
  void set_input_array(const array& arr);
  void set_output_array(const array& arr);

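  // Add a kernel launch node to the graph. The fold expression below
  // collects the address of every kernel argument into `ptrs`, the
  // void*[] form that the CUDA graph API expects.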
  template <typename F, typename... Params>
  void add_kernel_node(
      F* func,
      dim3 grid_dim,
      dim3 block_dim,
      uint32_t smem_bytes,
      Params&&... params) {
    constexpr size_t num = sizeof...(Params);
    void* ptrs[num];
    size_t i = 0;
    ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
         std::forward<Params>(params)),
     ...);
    add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
  }

  void add_kernel_node(
      CUfunction func,
      dim3 grid_dim,
      dim3 block_dim,
      uint32_t smem_bytes,
      void** params);

  void add_kernel_node(
      void* func,
      dim3 grid_dim,
      dim3 block_dim,
      uint32_t smem_bytes,
      void** params);

  void add_graph_node(cudaGraph_t child);

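  // Keep the array's data buffer alive at least until the current graph
  // has been committed and executed.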
  void add_temporary(const array& arr) {
    temporaries_.push_back(arr.data_shared_ptr());
  }

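  // Committing instantiates (or updates) the cached executable graph and
  // launches it on the stream; needs_commit() reports whether the graph
  // has grown past the configured per-graph limits.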
  void add_completed_handler(std::function<void()> task);
  bool needs_commit();
  void commit();

  Device& device() {
    return device_;
  }

  CudaStream& stream() {
    return stream_;
  }

  // Wait until kernels and completion handlers are finished
  void synchronize();

 private:
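  // Overloads for kernels described with the runtime API
  // (cudaKernelNodeParams) and the driver API (CUDA_KERNEL_NODE_PARAMS).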
  void add_kernel_node(const cudaKernelNodeParams& params);
  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);

  struct GraphNode {
    cudaGraphNode_t node;
    // K = kernel
    // E = empty
    // G* = subgraph (with metadata)
    // Symbols ':', '-' are reserved as separators
    std::string node_type;
    std::string id;
  };

  void insert_graph_dependencies(GraphNode node);
  void insert_graph_dependencies(std::vector<GraphNode> nodes);

  Device& device_;
  CudaStream stream_;
  CudaGraph graph_;
  Worker worker_;
  char node_count_{0};
  bool in_concurrent_{false};
  std::vector<cudaGraphNode_t> from_nodes_;
  std::vector<cudaGraphNode_t> to_nodes_;
  std::string graph_nodes_key_;
  std::string graph_deps_key_;
  std::vector<GraphNode> concurrent_nodes_;
  std::vector<std::shared_ptr<array::Data>> temporaries_;
  LRUCache<std::string, CudaGraphExec> graph_cache_;
  std::vector<std::uintptr_t> active_deps_;
  std::vector<std::uintptr_t> active_outputs_;
  std::unordered_map<std::uintptr_t, GraphNode> node_map_;
  size_t bytes_in_graph_{0};
  bool is_graph_updatable_{true};
  int max_ops_per_graph_;
  int max_mb_per_graph_;
};
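
// Typical encoder flow (illustrative sketch, not part of this header's
// contract; `my_kernel`, `in`, and `out` are hypothetical):
//
//   auto& enc = cu::get_command_encoder(stream);
//   enc.set_input_array(in);
//   enc.set_output_array(out);
//   enc.add_kernel_node(
//       my_kernel, grid_dim, block_dim, 0, in.data<float>(), out.data<float>());
//   if (enc.needs_commit()) {
//     enc.commit();
//   }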

class Device {
 public:
  explicit Device(int device);
  ~Device();

  Device(const Device&) = delete;
  Device& operator=(const Device&) = delete;

  // Make this device the current CUDA device; this method is thread-safe.
  void make_current();

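  // Get the command encoder for stream `s` (one encoder is kept per stream).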
  CommandEncoder& get_command_encoder(Stream s);

  int cuda_device() const {
    return device_;
  }
  int compute_capability_major() const {
    return compute_capability_major_;
  }
  int compute_capability_minor() const {
    return compute_capability_minor_;
  }
  cublasLtHandle_t lt_handle() const {
    return lt_;
  }
  cudnnHandle_t cudnn_handle() const {
    return cudnn_;
  }

 private:
  int device_;
  int compute_capability_major_;
  int compute_capability_minor_;
  std::string device_name_;
  cublasLtHandle_t lt_;
  cudnnHandle_t cudnn_;
  std::unordered_map<int, CommandEncoder> encoders_;
};

Device& device(mlx::core::Device device);
CommandEncoder& get_command_encoder(Stream s);

// Return an execution policy that does not synchronize for the result.
// Note that not all thrust APIs support an async policy; confirm before using.
inline auto thrust_policy(cudaStream_t stream) {
  // TODO: Connect thrust's custom allocator with mlx's allocator.
  return thrust::cuda::par_nosync.on(stream);
}
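
// Example (illustrative; assumes a cudaStream_t `stream` and device
// iterators `first`, `last`, `out`):
//
//   thrust::transform(
//       thrust_policy(stream), first, last, out, thrust::negate<float>());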

} // namespace mlx::core::cu