mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Split cuDNN helpers into a separate header (#2491)
* Add RAII managed CudaGraph class * Implement forward rms_norm with cuDNN * Revert back to old rms norm kernel
This commit is contained in:
@@ -21,7 +21,7 @@ class CommandEncoder {
|
||||
struct CaptureContext {
|
||||
CaptureContext(CommandEncoder& enc);
|
||||
~CaptureContext();
|
||||
cudaGraph_t graph;
|
||||
CudaGraph graph;
|
||||
CommandEncoder& enc;
|
||||
bool discard{false};
|
||||
};
|
||||
@@ -115,7 +115,7 @@ class CommandEncoder {
|
||||
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
cudaGraph_t graph_;
|
||||
CudaGraph graph_;
|
||||
Worker worker_;
|
||||
char node_count_{0};
|
||||
char graph_node_count_{0};
|
||||
|
||||
Reference in New Issue
Block a user