mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
nits + rebase
This commit is contained in:
parent
d2a826b3a4
commit
bb41f14add
@ -12,9 +12,6 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
class TestBase(mlx_tests.MLXTestCase):
|
||||
def test_update(self):
|
||||
pass
|
||||
|
||||
def test_module_utilities(self):
|
||||
m = nn.Sequential(
|
||||
nn.Sequential(nn.Linear(2, 10), nn.relu),
|
||||
@ -72,7 +69,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||
|
||||
def test_from_weights(self):
|
||||
def test_load_from_weights(self):
|
||||
m = nn.Linear(2, 2)
|
||||
|
||||
# Too few weights
|
||||
|
Loading…
Reference in New Issue
Block a user