# Copyright © 2023 Apple Inc.

import inspect
import math
import unittest
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
from mlx.utils import tree_flatten, tree_map, tree_unflatten


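# Collect every optimizer class exposed by mlx.optimizers (excluding the
# Optimizer base class) so the generic tests below cover all of them.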
def get_all_optimizers():
    classes = dict()
    for name, obj in inspect.getmembers(opt):
        if (
            inspect.isclass(obj)
            and issubclass(obj, opt.Optimizer)
            and obj != opt.Optimizer
        ):
            classes[name] = obj
    return classes


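# Map `fn` over one or more parameter trees and return True only if the result
# is truthy at every leaf.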
def tree_equal(fn, *args):
    return all(v for _, v in tree_flatten(tree_map(fn, *args)))


optimizers_dict = get_all_optimizers()


class TestOptimizers(mlx_tests.MLXTestCase):
    def test_optimizer_state(self):
        optim = opt.SGD(0.1)
        optim.state["hello"] = "world"
        self.assertEqual(optim.state["hello"], "world")

        optim.state = {0: 1}
        self.assertEqual(optim.state, {0: 1})

    def test_optimizers(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        for optim_class in optimizers_dict.values():
            optim = optim_class(0.1)
            update = optim.apply_gradients(grads, params)
            mx.eval(update)
            equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update)
            all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
            self.assertTrue(all_equal)

    def test_types_conserved(self):
        params = {"w": mx.ones((5, 5), mx.float16)}
        grads = tree_map(lambda x: mx.ones_like(x), params)
        for optim_class in optimizers_dict.values():
            optim = optim_class(0.1)
            update = optim.apply_gradients(grads, params)
            self.assertEqual(update["w"].dtype, mx.float16)

    def test_sgd(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

        # Implicit init
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        optim.apply_gradients(grads, params)
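        # Starting from a zero momentum buffer, a single update should leave
        # the state "v" equal to the gradient.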
        self.assertTrue(
            tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state)
        )

    def test_rmsprop(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.RMSprop(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

        # Implicit init
        alpha = 0.99
        optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
        optim.apply_gradients(grads, params)
        self.assertTrue(
            tree_equal(
                lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state
            )
        )

    def test_adagrad(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Adagrad(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adadelta(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.AdaDelta(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adam(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:
            optim = optimizer(learning_rate=1e-2)
            optim.init(params)
            self.assertTrue(
                tree_equal(
                    lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                    params,
                    optim.state,
                )
            )
            self.assertTrue(
                tree_equal(
                    lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
                    params,
                    optim.state,
                )
            )

    def test_lion(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Lion(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adafactor(self):
        x = mx.zeros((5, 5))
        params = {"x": x}
        grad = {"x": mx.ones_like(x)}
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_gradients(grad, params)
            self.assertEqual(xp["x"].dtype, x.dtype)
            self.assertEqual(xp["x"].shape, x.shape)

        x = mx.zeros((5, 5), mx.float16)
        params = {"x": x}
        grad = {"x": mx.ones_like(x)}
        optimizer = opt.Adafactor()
        for _ in range(2):
            xp = optimizer.apply_gradients(grad, params)
            self.assertEqual(xp["x"].dtype, x.dtype)
            self.assertEqual(xp["x"].shape, x.shape)
        self.assertEqual(optimizer.state["step"], 2)

    def test_compiled_optimizer(self):
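        # The uncompiled, compiled-pure, and compiled-impure update paths below
        # should all produce the same parameters from the same starting point.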
        model = nn.Linear(10, 10)
        x = mx.random.uniform(shape=(2, 10))
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

        orig_params = model.parameters()

        def loss(model, x):
            return model(x).sum()

        # Uncompiled version
        def step(x):
            _, grad = nn.value_and_grad(model, loss)(model, x)
            optim.update(model, grad)

        step(x)
        uncompiled_params = model.parameters()

        # Pure version
        def loss(params, x):
            model.update(params)
            return model(x).sum()

        model.update(orig_params)
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

        @mx.compile
        def step(params, opt_state, x):
            grad = mx.grad(loss)(params, x)
            optim.state = opt_state
            params = optim.apply_gradients(grad, params)
            return params, optim.state

        optim.init(model.parameters())
        pure_params, _ = step(model.parameters(), optim.state, x)
        self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"]))
        self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"]))

        # Impure version
        def loss(model, x):
            return model(x).sum()

        model.update(orig_params)
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        state = [model.state, optim.state]

        @partial(mx.compile, inputs=state, outputs=state)
        def step(x):
            _, grad = nn.value_and_grad(model, loss)(model, x)
            optim.update(model, grad)

        step(x)
        impure_params = model.parameters()
        self.assertTrue(
            mx.allclose(impure_params["weight"], uncompiled_params["weight"])
        )
        self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"]))

    def test_update_lr_compiled(self):
        params = {"w": mx.ones((5, 5))}
        grads = tree_map(lambda x: mx.ones_like(x), params)
        optim = opt.SGD(-1.0)

        @partial(mx.compile, inputs=optim.state)
        def update(grads):
            return optim.apply_gradients(grads, params)

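        # `params` is never reassigned, so each call starts from w = 1: with
        # learning_rate = -1 the update gives 1 - (-1) * 1 = 2, and after
        # switching to -2 it gives 1 - (-2) * 1 = 3.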
        result = update(grads)
        self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0)))
        optim.learning_rate = -2.0
        result = update(grads)
        self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))


class TestSchedulers(unittest.TestCase):
    def test_decay_lr(self):
        for optim_class in optimizers_dict.values():
            lr_schedule = opt.step_decay(1e-1, 0.9, 1)
            optimizer = optim_class(learning_rate=lr_schedule)

            params = {"w": mx.ones((5, 5))}
            grads = tree_map(lambda x: mx.ones_like(x), params)

            for it in range(10):
                optimizer.apply_gradients(grads, params)
                expected_lr = 0.1 * (0.9**it)
                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)

    def test_step_decay(self):
        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
        lr = lr_schedule(2500)
        expected_lr = 0.1 * (0.9**2)
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_exponential_decay(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.99)
        lr = lr_schedule(10)
        expected_lr = 0.1 * (0.99**10)
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_cosine_decay(self):
        lr_schedule = opt.cosine_decay(0.1, 10)
        lr = lr_schedule(4)
        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

        lr_schedule = opt.cosine_decay(0.1, 10, 0.05)
        lr = lr_schedule(9)
        expected_end_lr = 0.05
        self.assertGreater(lr, expected_end_lr)
        lr = lr_schedule(20)
        self.assertEqual(lr, expected_end_lr)

    def test_schedule_joiner(self):
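        # join_schedules expects exactly len(schedules) - 1 boundaries, so three
        # boundaries for three schedules should raise a ValueError.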
        boundaries = [2, 3, 4]
        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
        with self.assertRaises(ValueError):
            opt.schedulers.join_schedules(schedules, boundaries)
        boundaries = [2, 4]
        schedule = opt.schedulers.join_schedules(schedules, boundaries)
        self.assertEqual(schedule(0).item(), 3)
        self.assertEqual(schedule(1).item(), 3)
        self.assertEqual(schedule(2).item(), 4)
        self.assertEqual(schedule(3).item(), 4)
        self.assertEqual(schedule(5).item(), 5)
        self.assertEqual(schedule(7).item(), 5)

    def test_linear_warmup_with_cosine_decay(self):
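        # Warm up linearly from 0 to 1e-5 over the first 100 steps, then hand
        # off to a cosine decay of the same peak value at step 101.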
        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
        cos_with_warmup = opt.schedulers.join_schedules(
            [warmup_schedule, cosine_schedule], [101]
        )
        self.assertEqual(cos_with_warmup(0), 0.0)
        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
        optimizer = opt.Adam(learning_rate=cos_with_warmup)
        for _ in range(100):
            optimizer.update({}, {})
        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
        for _ in range(100):
            optimizer.update({}, {})
        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)

    def test_compile_with_schedule(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.9)
        optimizer = opt.SGD(learning_rate=lr_schedule)

        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
        def update():
            optimizer.update({}, {})

        for step in range(5):
            update()
            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())

    def test_clip_grad_norm(self):
        # Test with small gradients that do not require clipping
        small_grads = {
            "first": [mx.array([0.1, 0.2]), mx.array([0.1])],
            "second": mx.array([0.3]),
        }
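        # The overall norm here is sqrt(0.1**2 + 0.2**2 + 0.1**2 + 0.3**2),
        # roughly 0.39, well below max_norm.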
        max_norm = 10.0  # A large max_norm that shouldn't trigger clipping
        clipped_grads, total_norm = opt.clip_grad_norm(small_grads, max_norm)
        self.assertTrue(
            tree_equal(lambda x, y: mx.array_equal(x, y), small_grads, clipped_grads),
            "Gradients should not be modified when clipping is not necessary.",
        )

        # Test with large gradients that require clipping
        large_grads = {
            "first": [mx.array([10, 20]), mx.array([10])],
            "second": mx.array([30]),
        }
        max_norm = 1.0  # A small max_norm that should trigger clipping
        clipped_grads, total_norm = opt.clip_grad_norm(large_grads, max_norm)
        # Extract only the gradient values for the norm calculation
        clipped_values = [value for _, value in tree_flatten(clipped_grads)]
        norm_of_clipped = mx.sqrt(
            sum(mx.square(g).sum() for g in clipped_values)
        ).item()
        self.assertAlmostEqual(
            norm_of_clipped,
            max_norm,
            places=6,
            msg="Clipped gradients norm should be close to the specified max_norm.",
        )

        # Ensure the scaling was done correctly
        scale = max_norm / total_norm
        expected_grads = tree_map(lambda g: g * scale, large_grads)
        self.assertTrue(
            tree_equal(
                lambda x, y: mx.allclose(x, y, atol=1e-6), expected_grads, clipped_grads
            ),
            "Gradients were not scaled correctly during clipping.",
        )

    def test_init_from_state(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = nn.Linear(2, 2)
                self.drop = nn.Dropout(p=0.5)
                self.l2 = nn.Linear(2, 2)
                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]

        model = Model()
        optimizer = opt.Adam(learning_rate=3e-4)
        optimizer.init(model.trainable_parameters())

        # Flatten the state for serialization
        state = tree_flatten(optimizer.state)

        # Make a new optimizer and load the state
        optimizer = opt.Adam(learning_rate=3e-4)
        optimizer.state = tree_unflatten(state)

        # This should work without any errors
        grads = model.trainable_parameters()
        optimizer.update(model, grads)


if __name__ == "__main__":
    unittest.main()