nits + rebase

This commit is contained in:
Awni Hannun 2024-01-01 21:28:58 -08:00
parent d2a826b3a4
commit bb41f14add

View File

@ -12,9 +12,6 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
class TestBase(mlx_tests.MLXTestCase):
def test_update(self):
pass
def test_module_utilities(self):
m = nn.Sequential(
nn.Sequential(nn.Linear(2, 10), nn.relu),
@ -72,7 +69,7 @@ class TestBase(mlx_tests.MLXTestCase):
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
self.assertTrue(all(tree_flatten(eq_tree)))
def test_from_weights(self):
def test_load_from_weights(self):
m = nn.Linear(2, 2)
# Too few weights