diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 74e9bc761..eebd065de 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -12,9 +12,6 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten class TestBase(mlx_tests.MLXTestCase): - def test_update(self): - pass - def test_module_utilities(self): m = nn.Sequential( nn.Sequential(nn.Linear(2, 10), nn.relu), @@ -72,7 +69,7 @@ class TestBase(mlx_tests.MLXTestCase): eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters()) self.assertTrue(all(tree_flatten(eq_tree))) - def test_from_weights(self): + def test_load_from_weights(self): m = nn.Linear(2, 2) # Too few weights