Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
[CUDA] Tune ops per buffer based on device (#2761)

* tune ops per buffer based on device
* tune memory limit as well
* add tuning for spark
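The first bullet replaces the single compiled-in default of 20 ops per command buffer with a value chosen from the device the stream runs on. Below is a minimal sketch of that idea using plain CUDA runtime calls; the function names, the compute-capability check, and the thresholds are illustrative assumptions, not MLX's actual tuning logic. Only the MLX_MAX_OPS_PER_BUFFER override mirrors the behavior named in the diff.

```cpp
// Hypothetical sketch, not the MLX implementation: pick an ops-per-buffer
// default from the device, but let MLX_MAX_OPS_PER_BUFFER override it.
#include <cstdlib>
#include <string>
#include <cuda_runtime.h>

// Illustrative device-based default; the real per-device values are tuned.
int default_ops_per_buffer(int device) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  if (prop.major >= 10) {  // assume newer architectures can batch more work
    return 40;
  }
  return 20;  // the previous fixed default
}

int max_ops_per_buffer(int device) {
  // The environment override takes precedence over the device default.
  if (const char* env = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
    return std::stoi(env);
  }
  return default_ops_per_buffer(device);
}
```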
@@ -11,9 +11,6 @@
 namespace mlx::core::gpu {
 
-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-constexpr int default_max_nodes_per_graph = 20;
-
 bool is_available() {
   return true;
 }
 
@@ -53,8 +50,7 @@ void eval(array& arr) {
     encoder.add_temporary(s);
   }
 
-  if (encoder.get_num_ops() >=
-      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+  if (encoder.needs_commit()) {
     scheduler::notify_new_task(stream);
     encoder.add_completed_handler(
         [stream]() { scheduler::notify_task_completion(stream); });