cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

consistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

Author: Awni Hannun
Date: 2025-06-22 14:32:10 -07:00
Parent: e76e9b87f0
Commit: 0d4a0e6531
36 changed files with 1461 additions and 1212 deletions

@@ -391,9 +391,11 @@ class TestLoad(mlx_tests.MLXTestCase):
         scale = mx.array(2.0)
         y = mx.load(save_file)
         mx.eval(y)
+        mx.synchronize()
         load_only = mx.get_peak_memory()
         y = mx.load(save_file) * scale
         mx.eval(y)
+        mx.synchronize()
         load_with_binary = mx.get_peak_memory()
         self.assertEqual(load_only, load_with_binary)