mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Tune ops per buffer based on device (#2761)
* tune ops per buffer based on device * tune memory limit as well * add tuning for spark
This commit is contained in:
@@ -84,7 +84,7 @@ class CommandEncoder {
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
int get_num_ops();
|
||||
bool needs_commit();
|
||||
void commit();
|
||||
|
||||
Device& device() {
|
||||
@@ -131,6 +131,9 @@ class CommandEncoder {
|
||||
std::vector<std::uintptr_t> active_deps_;
|
||||
std::vector<std::uintptr_t> active_outputs_;
|
||||
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
||||
size_t bytes_in_graph_{0};
|
||||
int max_ops_per_graph_;
|
||||
int max_mb_per_graph_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
@@ -166,6 +169,7 @@ class Device {
|
||||
int device_;
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
std::string device_name_;
|
||||
cublasLtHandle_t lt_;
|
||||
cudnnHandle_t cudnn_;
|
||||
std::unordered_map<int, CommandEncoder> encoders_;
|
||||
|
||||
Reference in New Issue
Block a user