mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add top-level namespace access for gradient control functions
Users can now call mx.no_grad() and mx.enable_grad() directly without importing from mlx.autograd module, providing better PyTorch compatibility. - Import gradient context managers from mlx.autograd in transforms.cpp - Add fallback handling if autograd module import fails - Update tests to verify both mx.no_grad() and mlx.autograd.no_grad() work
This commit is contained in:

committed by
Yannick Müller

parent
323a8e958d
commit
df9ac5f2f9
@@ -1528,4 +1528,16 @@ void init_transforms(nb::module_& m) {
|
||||
tree_cache().clear();
|
||||
mx::detail::compile_clear_cache();
|
||||
}));
|
||||
|
||||
// Import gradient control context managers from mlx.autograd
|
||||
// This allows mx.no_grad() instead of requiring mlx.autograd.no_grad()
|
||||
try {
|
||||
auto autograd = nb::module_::import_("mlx.autograd");
|
||||
m.attr("no_grad") = autograd.attr("no_grad");
|
||||
m.attr("enable_grad") = autograd.attr("enable_grad");
|
||||
// Note: set_grad_enabled is already exposed as a function above
|
||||
} catch (...) {
|
||||
// If import fails, these functions won't be available at top level
|
||||
// but the module will still load successfully
|
||||
}
|
||||
}
|
||||
|
@@ -831,6 +831,13 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
# Check that gradient mode is restored
|
||||
self.assertTrue(mx.is_grad_enabled())
|
||||
|
||||
# Test that mx.no_grad() also works (top-level import)
|
||||
with mx.no_grad():
|
||||
self.assertFalse(mx.is_grad_enabled())
|
||||
out, grad = mx.vjp(fun, x, mx.array(1.0))
|
||||
self.assertEqual(out.item(), 4.0)
|
||||
self.assertEqual(grad.item(), 0.0)
|
||||
|
||||
def test_no_grad_decorator(self):
|
||||
"""Test no_grad as a decorator."""
|
||||
from mlx.autograd import no_grad
|
||||
|
49
test_namespace_mx_no_grad.py
Normal file
49
test_namespace_mx_no_grad.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add the python directory to the path to test our local build
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "python"))
|
||||
|
||||
try:
|
||||
import mlx.core as mx
|
||||
|
||||
print("✓ Successfully imported mlx.core as mx")
|
||||
|
||||
# Test that no_grad is available at top level
|
||||
if hasattr(mx, "no_grad"):
|
||||
print("✓ mx.no_grad is available at top level")
|
||||
|
||||
# Test that it works
|
||||
x = mx.array(2.0)
|
||||
|
||||
def f(x):
|
||||
return x * x
|
||||
|
||||
grad_fn = mx.value_and_grad(f)
|
||||
|
||||
# Test normal gradient
|
||||
y, dydx = grad_fn(x)
|
||||
print(f"✓ Normal gradient: f(2) = {y.item()}, df/dx = {dydx.item()}")
|
||||
|
||||
# Test with mx.no_grad()
|
||||
with mx.no_grad():
|
||||
y2, dydx2 = grad_fn(x)
|
||||
print(f"✓ With mx.no_grad(): f(2) = {y2.item()}, df/dx = {dydx2.item()}")
|
||||
|
||||
print("✓ mx.no_grad() works correctly at top level!")
|
||||
|
||||
else:
|
||||
print("✗ mx.no_grad is NOT available at top level")
|
||||
|
||||
# Test that enable_grad is also available
|
||||
if hasattr(mx, "enable_grad"):
|
||||
print("✓ mx.enable_grad is available at top level")
|
||||
else:
|
||||
print("✗ mx.enable_grad is NOT available at top level")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"✗ Failed to import mlx.core: {e}")
|
||||
except Exception as e:
|
||||
print(f"✗ Error during testing: {e}")
|
BIN
test_no_grad
BIN
test_no_grad
Binary file not shown.
Reference in New Issue
Block a user