mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
nits + rebase
This commit is contained in:
parent
d2a826b3a4
commit
bb41f14add
@ -12,9 +12,6 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
|||||||
|
|
||||||
|
|
||||||
class TestBase(mlx_tests.MLXTestCase):
|
class TestBase(mlx_tests.MLXTestCase):
|
||||||
def test_update(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_module_utilities(self):
|
def test_module_utilities(self):
|
||||||
m = nn.Sequential(
|
m = nn.Sequential(
|
||||||
nn.Sequential(nn.Linear(2, 10), nn.relu),
|
nn.Sequential(nn.Linear(2, 10), nn.relu),
|
||||||
@ -72,7 +69,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||||
|
|
||||||
def test_from_weights(self):
|
def test_load_from_weights(self):
|
||||||
m = nn.Linear(2, 2)
|
m = nn.Linear(2, 2)
|
||||||
|
|
||||||
# Too few weights
|
# Too few weights
|
||||||
|
Loading…
Reference in New Issue
Block a user