Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-07 11:48:55 +08:00.
[CUDA] Tune ops per buffer based on device (#2761)
* tune ops per buffer based on device
* tune memory limit as well
* add tuning for spark
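Both limits remain user-tunable: the device-specific values introduced below are only defaults, which MLX_MAX_OPS_PER_BUFFER and MLX_MAX_MB_PER_BUFFER can still override. The diff does not show how the env:: accessors read those variables; here is a minimal sketch of the assumed pattern (helper names are hypothetical, not from this commit):

#include <cstdlib>
#include <string>

// Hypothetical helper: read an integer environment variable, falling
// back to the device-tuned default when the variable is unset.
static int env_int_or(const char* name, int tuned_default) {
  const char* v = std::getenv(name);
  return v ? std::stoi(v) : tuned_default;
}

// Assumed shape of the env:: accessors called by get_graph_limits.
int max_ops_per_buffer(int tuned) {
  return env_int_or("MLX_MAX_OPS_PER_BUFFER", tuned);
}
int max_mb_per_buffer(int tuned) {
  return env_int_or("MLX_MAX_MB_PER_BUFFER", tuned);
}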
mlx/backend/cuda/device.cpp
@@ -46,6 +46,7 @@ Device::Device(int device) : device_(device) {
         "Device {} does not support synchronization in managed memory.",
         device_));
   }

   // The cublasLt handle is used by matmul.
   make_current();
   CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
@@ -189,12 +190,41 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
   }
 }

+// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
+std::pair<int, int> get_graph_limits(Device& d) {
+  auto cc =
+      d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
+  int ops = 20;
+  int mb = 100;
+  switch (cc) {
+    case 800: // A100
+      ops = 20;
+      mb = 400;
+      break;
+    case 900: // H100
+      ops = 30;
+      mb = 400;
+      break;
+    case 1000: // B200
+      ops = 50;
+      mb = 500;
+      break;
+    case 1210: // DGX Spark
+      ops = 20;
+      mb = 25;
+      break;
+  }
+  return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
+}
+
 CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
       worker_(d),
-      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
+  std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
+}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
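get_graph_limits keys its switch on major * 100 + minor * 10, so compute capability 8.0 becomes 800, 9.0 becomes 900, 10.0 becomes 1000, and 12.1 (DGX Spark) becomes 1210. Any device not listed, such as an SM 8.6 consumer Ampere part, falls through the switch and keeps the conservative 20-op / 100 MB defaults. A standalone sketch of that behavior (function name hypothetical):

#include <cassert>
#include <utility>

// Mirrors the encoding in get_graph_limits: major*100 + minor*10.
std::pair<int, int> limits_for(int major, int minor) {
  int cc = major * 100 + minor * 10;
  int ops = 20, mb = 100; // fallback for unlisted devices
  switch (cc) {
    case 800:  ops = 20; mb = 400; break; // A100
    case 900:  ops = 30; mb = 400; break; // H100
    case 1000: ops = 50; mb = 500; break; // B200
    case 1210: ops = 20; mb = 25;  break; // DGX Spark
  }
  return {ops, mb};
}

int main() {
  assert(limits_for(9, 0) == std::make_pair(30, 400));  // H100
  assert(limits_for(8, 6) == std::make_pair(20, 100));  // SM 8.6 falls through
  assert(limits_for(12, 1) == std::make_pair(20, 25));  // DGX Spark
}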
@@ -204,6 +234,7 @@ void CommandEncoder::set_input_array(const array& arr) {
   if (!use_cuda_graphs()) {
     return;
   }
+  bytes_in_graph_ += arr.data_size();
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
@@ -301,8 +332,9 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }

-int CommandEncoder::get_num_ops() {
-  return node_count_;
+bool CommandEncoder::needs_commit() {
+  return (node_count_ > max_ops_per_graph_) ||
+      ((bytes_in_graph_ >> 20) > max_mb_per_graph_);
 }

 void CommandEncoder::commit() {
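needs_commit compares the op count and the accumulated byte count against the tuned limits; the `>> 20` converts bytes to whole MiB (2^20 bytes) in integer arithmetic, truncating any remainder. Because both comparisons use `>`, a graph is committed only after a limit is exceeded, not when it is exactly reached.

#include <cstddef>

// >> 20 divides by 2^20 = 1,048,576, i.e. converts bytes to whole MiB.
static_assert((std::size_t{400} << 20) >> 20 == 400);
static_assert((std::size_t{25} * 1024 * 1024 + 12345) >> 20 == 25); // remainder truncates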
@@ -365,6 +397,7 @@ void CommandEncoder::commit() {
   // Put completion handlers in a batch.
   worker_.commit(stream_);
   node_count_ = 0;
+  bytes_in_graph_ = 0;
 }

 void CommandEncoder::synchronize() {
mlx/backend/cuda/device.h
@@ -84,7 +84,7 @@ class CommandEncoder {
   }

   void add_completed_handler(std::function<void()> task);
-  int get_num_ops();
+  bool needs_commit();
   void commit();

   Device& device() {
@@ -131,6 +131,9 @@ class CommandEncoder {
   std::vector<std::uintptr_t> active_deps_;
   std::vector<std::uintptr_t> active_outputs_;
   std::unordered_map<std::uintptr_t, GraphNode> node_map_;
+  size_t bytes_in_graph_{0};
+  int max_ops_per_graph_;
+  int max_mb_per_graph_;
 };

 class Device {
@@ -166,6 +169,7 @@ class Device {
   int device_;
   int compute_capability_major_;
   int compute_capability_minor_;
+  std::string device_name_;
   cublasLtHandle_t lt_;
   cudnnHandle_t cudnn_;
   std::unordered_map<int, CommandEncoder> encoders_;
mlx/backend/cuda/eval.cpp
@@ -11,9 +11,6 @@

 namespace mlx::core::gpu {

-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-constexpr int default_max_nodes_per_graph = 20;
-
 bool is_available() {
   return true;
 }
@@ -53,8 +50,7 @@ void eval(array& arr) {
     encoder.add_temporary(s);
   }

-  if (encoder.get_num_ops() >=
-      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+  if (encoder.needs_commit()) {
     scheduler::notify_new_task(stream);
     encoder.add_completed_handler(
         [stream]() { scheduler::notify_task_completion(stream); });
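With this change eval.cpp no longer knows about any limit constants; the encoder owns the policy and the caller just asks a yes/no question. One practical consequence: since get_graph_limits is captured once in the CommandEncoder constructor, an override must be in the environment before the first CUDA stream is used. A hypothetical C++-side override (a shell `export` before launch works equally well):

#include <cstdlib>

int main() {
  // Illustrative only: raise the batching limits for an unlisted GPU.
  // Must run before MLX constructs its CUDA CommandEncoder, since the
  // tuned limits are read once at encoder construction.
  setenv("MLX_MAX_OPS_PER_BUFFER", "40", /* overwrite */ 1);
  setenv("MLX_MAX_MB_PER_BUFFER", "400", 1);
  // ... run the MLX workload as usual ...
}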
mlx/backend/metal/device.cpp
@@ -382,11 +382,8 @@ MTL::CommandQueue* Device::get_queue(Stream stream) {

 bool Device::command_buffer_needs_commit(int index) {
   auto& stream = get_stream_(index);
-  if (stream.buffer_ops > max_ops_per_buffer_ ||
-      (stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
-    return true;
-  }
-  return false;
+  return (stream.buffer_ops > max_ops_per_buffer_) ||
+      ((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
 }

 MTL::CommandBuffer* Device::get_command_buffer(int index) {