mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
Add a multi optimizer (#1916)
This commit is contained in:

committed by
GitHub

parent
a0737273d3
commit
9680f72cca
@@ -39,6 +39,7 @@ def tree_equal(fn, *args):
|
||||
|
||||
|
||||
optimizers_dict = get_all_optimizers()
|
||||
del optimizers_dict["MultiOptimizer"]
|
||||
|
||||
|
||||
class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
@@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase):
|
||||
grads = model.trainable_parameters()
|
||||
optimizer.update(model, grads)
|
||||
|
||||
def test_multi_optimizer(self):
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l1 = nn.Linear(2, 2)
|
||||
self.drop = nn.Dropout(p=0.5)
|
||||
self.l2 = nn.Linear(2, 2)
|
||||
self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
|
||||
|
||||
model = Model()
|
||||
optimizer = opt.MultiOptimizer(
|
||||
[opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
|
||||
[lambda name, weight: weight.ndim > 1],
|
||||
)
|
||||
optimizer.init(model.trainable_parameters())
|
||||
|
||||
self.assertEqual(len(optimizer.state["states"]), 2)
|
||||
|
||||
adam_states = tree_flatten(optimizer.state["states"][0])
|
||||
sgd_states = tree_flatten(optimizer.state["states"][1])
|
||||
self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)
|
||||
self.assertFalse(any("bias" in k for k, v in adam_states))
|
||||
self.assertFalse(any("weight" in k for k, v in sgd_states))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -3,6 +3,7 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.utils
|
||||
import mlx_tests
|
||||
|
||||
@@ -22,6 +23,29 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(list(zip(*flat_tree))[1], vals)
|
||||
self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
|
||||
|
||||
def test_merge(self):
|
||||
t1 = {"a": 0}
|
||||
t2 = {"b": 1}
|
||||
t = mlx.utils.tree_merge(t1, t2)
|
||||
self.assertEqual({"a": 0, "b": 1}, t)
|
||||
with self.assertRaises(ValueError):
|
||||
mlx.utils.tree_merge(t1, t1)
|
||||
with self.assertRaises(ValueError):
|
||||
mlx.utils.tree_merge(t, t1)
|
||||
|
||||
mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
|
||||
mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
|
||||
mod = nn.Sequential(mod1, mod2)
|
||||
|
||||
params1 = {"layers": [mod1.parameters()]}
|
||||
params2 = {"layers": [None, mod2.parameters()]}
|
||||
params = mlx.utils.tree_merge(params1, params2)
|
||||
for (k1, v1), (k2, v2) in zip(
|
||||
mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters())
|
||||
):
|
||||
self.assertEqual(k1, k2)
|
||||
self.assertTrue(mx.array_equal(v1, v2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -345,6 +345,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
)
|
||||
|
||||
def test_vmap_inverse(self):
|
||||
mx.random.seed(42)
|
||||
a = mx.random.uniform(shape=(3, 4, 4))
|
||||
|
||||
cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)
|
||||
|
Reference in New Issue
Block a user