Add top-level namespace access for gradient control functions

Users can now call mx.no_grad() and mx.enable_grad() directly, without
importing them from the mlx.autograd module, which improves compatibility
with PyTorch's API.

- Import gradient context managers from mlx.autograd in transforms.cpp
- Add fallback handling if autograd module import fails
- Update tests to verify both mx.no_grad() and mlx.autograd.no_grad() work
Yannick Muller
2025-08-03 00:46:27 -04:00
committed by Yannick Müller
parent 323a8e958d
commit df9ac5f2f9
4 changed files with 68 additions and 0 deletions
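
For context, a minimal usage sketch of the new top-level aliases (hedged: this assumes a build of this fork in which the mlx.autograd module is present):

import mlx.core as mx

x = mx.array(3.0)

# New: the context manager is reachable directly on mx, mirroring torch.no_grad()
with mx.no_grad():
    y = x * x  # runs with gradient tracking disabled

# The fully qualified spelling continues to work as before
from mlx.autograd import no_grad

with no_grad():
    y = x * x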

View File

@@ -1528,4 +1528,16 @@ void init_transforms(nb::module_& m) {
        tree_cache().clear();
        mx::detail::compile_clear_cache();
      }));

  // Import the gradient control context managers from mlx.autograd.
  // This allows mx.no_grad() instead of requiring mlx.autograd.no_grad().
  try {
    auto autograd = nb::module_::import_("mlx.autograd");
    m.attr("no_grad") = autograd.attr("no_grad");
    m.attr("enable_grad") = autograd.attr("enable_grad");
    // Note: set_grad_enabled is already exposed as a function above.
  } catch (...) {
    // If the import fails, these functions won't be available at the top
    // level, but the module will still load successfully.
  }
}
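
In pure-Python terms, the re-export above behaves roughly like the following module-level fallback (a sketch for illustration only; mlx.core is a native extension, so the real logic lives in the C++ binding above):

try:
    # Re-export the context managers at the top level of the package
    from mlx.autograd import enable_grad, no_grad
except ImportError:
    # mlx.autograd is unavailable: mx.no_grad / mx.enable_grad are simply
    # absent, but importing mlx.core still succeeds
    pass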

View File

@@ -831,6 +831,13 @@ class TestAutograd(mlx_tests.MLXTestCase):
        # Check that gradient mode is restored
        self.assertTrue(mx.is_grad_enabled())

        # Test that mx.no_grad() also works (top-level import)
        with mx.no_grad():
            self.assertFalse(mx.is_grad_enabled())
            out, grad = mx.vjp(fun, x, mx.array(1.0))
            self.assertEqual(out.item(), 4.0)
            self.assertEqual(grad.item(), 0.0)

    def test_no_grad_decorator(self):
        """Test no_grad as a decorator."""
        from mlx.autograd import no_grad
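
For reference, the decorator form exercised by test_no_grad_decorator presumably looks like the sketch below; whether it is applied as no_grad or no_grad() depends on the mlx.autograd implementation, which this diff does not show (PyTorch accepts both spellings):

import mlx.core as mx
from mlx.autograd import no_grad

@no_grad()
def inference_step(x):
    # Gradient tracking is disabled for the whole function body
    return x * x

print(inference_step(mx.array(2.0)).item())  # 4.0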

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python3
import os
import sys

# Add the python directory to the path to test our local build
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "python"))

try:
    import mlx.core as mx

    print("✓ Successfully imported mlx.core as mx")

    # Test that no_grad is available at top level
    if hasattr(mx, "no_grad"):
        print("✓ mx.no_grad is available at top level")

        # Test that it works
        x = mx.array(2.0)

        def f(x):
            return x * x

        grad_fn = mx.value_and_grad(f)

        # Test normal gradient
        y, dydx = grad_fn(x)
        print(f"✓ Normal gradient: f(2) = {y.item()}, df/dx = {dydx.item()}")

        # Test with mx.no_grad()
        with mx.no_grad():
            y2, dydx2 = grad_fn(x)
            print(f"✓ With mx.no_grad(): f(2) = {y2.item()}, df/dx = {dydx2.item()}")

        print("✓ mx.no_grad() works correctly at top level!")
    else:
        print("✗ mx.no_grad is NOT available at top level")

    # Test that enable_grad is also available
    if hasattr(mx, "enable_grad"):
        print("✓ mx.enable_grad is available at top level")
    else:
        print("✗ mx.enable_grad is NOT available at top level")

except ImportError as e:
    print(f"✗ Failed to import mlx.core: {e}")
except Exception as e:
    print(f"✗ Error during testing: {e}")

Binary file not shown.