mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	* initial commit with workong optmimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove coments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
		
			
				
	
	
		
			585 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			585 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright © 2023 Apple Inc.
 | 
						|
 | 
						|
import inspect
 | 
						|
import math
 | 
						|
import unittest
 | 
						|
from functools import partial
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
import mlx.nn as nn
 | 
						|
import mlx.optimizers as opt
 | 
						|
import mlx.utils
 | 
						|
import mlx_tests
 | 
						|
import numpy as np
 | 
						|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
 | 
						|
 | 
						|
try:
 | 
						|
    import torch
 | 
						|
    import torch.nn.functional as F
 | 
						|
 | 
						|
    has_torch = True
 | 
						|
except ImportError as e:
 | 
						|
    has_torch = False
 | 
						|
 | 
						|
 | 
						|
def get_all_optimizers():
    """Collect the optimizer classes exposed by the ``opt`` module.

    Returns:
        dict: ``{name: class}`` for every member of ``opt`` that is a class
        derived from ``opt.Optimizer``, leaving out ``opt.Optimizer`` itself.
    """

    def is_optimizer_subclass(member):
        # Check isclass first: issubclass raises TypeError on non-class objects.
        if not inspect.isclass(member):
            return False
        return issubclass(member, opt.Optimizer) and member != opt.Optimizer

    return dict(inspect.getmembers(opt, is_optimizer_subclass))
 | 
						|
 | 
						|
 | 
						|
def tree_equal(fn, *args):
    """Return ``True`` when ``fn`` is truthy at every leaf of the mapped trees.

    ``fn`` is applied across ``args`` with ``tree_map``; the resulting tree is
    flattened and each leaf value is tested for truthiness, stopping at the
    first falsy one.
    """
    mapped = tree_map(fn, *args)
    for _, leaf_result in tree_flatten(mapped):
        if not leaf_result:
            return False
    return True
 | 
						|
 | 
						|
 | 
						|
# Registry of optimizer classes iterated by the generic tests below.
# MultiOptimizer is removed from it -- presumably because those tests build
# each class from a single learning rate, whereas MultiOptimizer is built from
# a list of optimizers and filters (see test_multi_optimizer).
optimizers_dict = get_all_optimizers()
optimizers_dict.pop("MultiOptimizer")
 | 
						|
 | 
						|
 | 
						|
class TestOptimizers(mlx_tests.MLXTestCase):
 | 
						|
    def test_optimizer_state(self):
        sgd = opt.SGD(0.1)

        # Arbitrary entries can be written into the state and read back.
        sgd.state["hello"] = "world"
        self.assertEqual(sgd.state["hello"], "world")

        # The whole state can also be replaced by assignment.
        sgd.state = {0: 1}
        self.assertEqual(sgd.state, {0: 1})
 | 
						|
 | 
						|
    def test_optimizers(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(mx.ones_like, params)

        # Every registered optimizer must return an update tree whose leaves
        # have the same shapes as the parameters.
        for optimizer_cls in optimizers_dict.values():
            optimizer = optimizer_cls(0.1)
            updated = optimizer.apply_gradients(grads, params)
            mx.eval(updated)
            self.assertTrue(
                tree_equal(lambda p, u: p.shape == u.shape, params, updated)
            )
 | 
						|
 | 
						|
    def test_types_conserved(self):
        # float16 parameters must still be float16 after one update, for every
        # registered optimizer.
        half_params = {"w": mx.ones((5, 5), mx.float16)}
        half_grads = tree_map(mx.ones_like, half_params)
        for optimizer_cls in optimizers_dict.values():
            optimizer = optimizer_cls(0.1)
            updated = optimizer.apply_gradients(half_grads, half_params)
            self.assertEqual(updated["w"].dtype, mx.float16)
 | 
						|
 | 
						|
    def test_sgd(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(mx.ones_like, params)

        def velocity_is_zero(param, state):
            return mx.array_equal(state["v"], mx.zeros_like(param))

        def velocity_is_grad(grad, state):
            return mx.array_equal(state["v"], grad)

        # Explicit init: the "v" buffers start out as zeros.
        explicit = opt.SGD(learning_rate=1e-2, momentum=0.9)
        explicit.init(params)
        self.assertTrue(tree_equal(velocity_is_zero, params, explicit.state))

        # Implicit init: after the first update "v" equals the gradient.
        implicit = opt.SGD(learning_rate=1e-2, momentum=0.9)
        implicit.apply_gradients(grads, params)
        self.assertTrue(tree_equal(velocity_is_grad, grads, implicit.state))
 | 
						|
 | 
						|
    def test_rmsprop(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(mx.ones_like, params)

        # Explicit init: the "v" buffers start out as zeros.
        explicit = opt.RMSprop(learning_rate=1e-2)
        explicit.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                explicit.state,
            )
        )

        # Implicit init: after one update "v" is close to (1 - alpha) * g.
        alpha = 0.99
        implicit = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
        implicit.apply_gradients(grads, params)

        def matches_first_step(grad, state):
            return mx.allclose(state["v"], (1 - alpha) * grad)

        self.assertTrue(tree_equal(matches_first_step, grads, implicit.state))
 | 
						|
 | 
						|
    def test_adagrad(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        # Built as in the sibling tests; not used by the assertion below.
        grads = tree_map(mx.ones_like, params)

        def accumulator_is_zero(param, state):
            return mx.array_equal(state["v"], mx.zeros_like(param))

        # Explicit init: the "v" buffers start out as zeros.
        adagrad = opt.Adagrad(learning_rate=1e-2)
        adagrad.init(params)
        self.assertTrue(tree_equal(accumulator_is_zero, params, adagrad.state))
 | 
						|
 | 
						|
    def test_adadelta(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        # Built as in the sibling tests; not used by the assertions below.
        grads = tree_map(mx.ones_like, params)

        def zero_buffer_check(key):
            # Leaf predicate: the state entry `key` is all zeros.
            return lambda p, s: mx.array_equal(s[key], mx.zeros_like(p))

        # Explicit init: both the "v" and "u" buffers start out as zeros.
        adadelta = opt.AdaDelta(learning_rate=1e-2)
        adadelta.init(params)
        for buffer_name in ("v", "u"):
            self.assertTrue(
                tree_equal(zero_buffer_check(buffer_name), params, adadelta.state)
            )
 | 
						|
 | 
						|
    def test_adam(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(mx.ones_like, params)

        def zero_moment_check(key):
            # Leaf predicate: the state entry `key` is all zeros.
            return lambda p, s: mx.array_equal(s[key], mx.zeros_like(p))

        # Explicit init: "v" and "m" start out as zeros for each Adam variant.
        for optimizer_cls in (opt.Adam, opt.AdamW, opt.Adamax):
            optimizer = optimizer_cls(learning_rate=1e-2)
            optimizer.init(params)
            for moment in ("v", "m"):
                self.assertTrue(
                    tree_equal(zero_moment_check(moment), params, optimizer.state)
                )

        # float16 parameters and gradients must produce float16 parameters.
        half_params = tree_map(lambda x: x.astype(mx.float16), params)
        half_grads = tree_map(lambda x: x.astype(mx.float16), grads)
        adam = opt.Adam(1e-2, bias_correction=True)
        updated = adam.apply_gradients(half_grads, half_params)
        self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, updated))
 | 
						|
 | 
						|
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_adamw_matches_pytorch(self):
        mx.random.seed(0)
        np.random.seed(0)

        mlx_model = nn.Linear(3, 1)
        # Snapshot the initial parameters before the MLX update so the torch
        # model can start from the same values.
        weight0 = np.array(mlx_model.weight.tolist())
        bias0 = np.array(mlx_model.bias.tolist())

        def mse(model, inputs, targets):
            return nn.losses.mse_loss(model(inputs), targets)

        x = np.random.rand(3, 3)
        y = np.random.rand(3, 1)

        # One AdamW step in MLX.
        mlx_optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
        value_and_grad = nn.value_and_grad(mlx_model, mse)
        _, grads = value_and_grad(mlx_model, mx.array(x), mx.array(y))
        mlx_optimizer.update(mlx_model, grads)

        # The same step in torch, starting from the copied parameters.
        torch_model = torch.nn.Linear(3, 1)
        torch_model.weight.data = torch.tensor(weight0, dtype=torch.float32)
        torch_model.bias.data = torch.tensor(bias0, dtype=torch.float32)

        torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
        torch_optimizer.zero_grad()
        torch_pred = torch_model(torch.tensor(x, dtype=torch.float32))
        torch_loss = torch.nn.MSELoss()(
            torch_pred, torch.tensor(y, dtype=torch.float32)
        )
        torch_loss.backward()
        torch_optimizer.step()

        # Gradients must agree ...
        for name, param in torch_model.named_parameters():
            torch_grad = param.grad.detach().numpy()
            self.assertTrue(np.allclose(torch_grad, np.array(grads[name])))

        # ... and so must the updated parameters.
        for name, param in torch_model.named_parameters():
            torch_param = param.data.detach().numpy()
            self.assertTrue(np.allclose(torch_param, np.array(mlx_model[name])))
 | 
						|
 | 
						|
    def test_lion(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        # Built as in the sibling tests; not used by the assertion below.
        grads = tree_map(mx.ones_like, params)

        def momentum_is_zero(param, state):
            return mx.array_equal(state["m"], mx.zeros_like(param))

        # Explicit init: the "m" buffers start out as zeros.
        lion = opt.Lion(learning_rate=1e-2)
        lion.init(params)
        self.assertTrue(tree_equal(momentum_is_zero, params, lion.state))
 | 
						|
 | 
						|
    def test_adafactor(self):
        def take_two_steps(x):
            # Apply two updates with a fresh Adafactor, checking after each
            # that dtype and shape are unchanged; return the optimizer.
            optimizer = opt.Adafactor()
            params = {"x": x}
            grad = {"x": mx.ones_like(x)}
            for _ in range(2):
                updated = optimizer.apply_gradients(grad, params)
                self.assertEqual(updated["x"].dtype, x.dtype)
                self.assertEqual(updated["x"].shape, x.shape)
            return optimizer

        # Default dtype first, then float16.
        take_two_steps(mx.zeros((5, 5)))
        half_optimizer = take_two_steps(mx.zeros((5, 5), mx.float16))

        # Two updates must have advanced the step counter to 2.
        self.assertEqual(half_optimizer.state["step"], 2)
 | 
						|
 | 
						|
    def test_muon(self):
        # Parameters of rank 2, rank 1 and rank 4.
        params = {
            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
            "second": mx.zeros((3, 3)),
            "conv": mx.zeros((16, 8, 3, 3)),
        }
        grads = tree_map(mx.ones_like, params)

        def velocity_is_zero(param, state):
            return mx.array_equal(state["v"], mx.zeros_like(param))

        # Explicit init: the "v" buffers start out as zeros.
        muon = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
        muon.init(params)
        self.assertTrue(tree_equal(velocity_is_zero, params, muon.state))

        updated = muon.apply_gradients(grads, params)

        # The update keeps every parameter's shape ...
        self.assertTrue(
            tree_equal(lambda p, u: p.shape == u.shape, params, updated)
        )
        # ... and does not leave every parameter unchanged.
        self.assertFalse(tree_equal(mx.array_equal, params, updated))

        # Other configurations only need to run without raising.
        for config in ({"momentum": 0.95, "nesterov": False}, {"momentum": 0.0}):
            variant = opt.Muon(learning_rate=1e-2, **config)
            variant.apply_gradients(grads, params)
 | 
						|
 | 
						|
    def test_compiled_optimizer(self):
        model = nn.Linear(10, 10)
        x = mx.random.uniform(shape=(2, 10))
        initial_params = model.parameters()

        def fresh_sgd():
            return opt.SGD(learning_rate=1e-2, momentum=0.9)

        def model_loss(model, x):
            return model(x).sum()

        # Reference: one uncompiled step.
        eager_optim = fresh_sgd()

        def eager_step(x):
            _, grad = nn.value_and_grad(model, model_loss)(model, x)
            eager_optim.update(model, grad)

        eager_step(x)
        reference = model.parameters()

        # Pure compiled step: parameters and optimizer state are passed in
        # and returned explicitly.
        def pure_loss(params, x):
            model.update(params)
            return model(x).sum()

        # Rewind the model and start from a fresh optimizer.
        model.update(initial_params)
        pure_optim = fresh_sgd()

        @mx.compile
        def pure_step(params, opt_state, x):
            grad = mx.grad(pure_loss)(params, x)
            pure_optim.state = opt_state
            new_params = pure_optim.apply_gradients(grad, params)
            return new_params, pure_optim.state

        pure_optim.init(model.parameters())
        pure_params, _ = pure_step(model.parameters(), pure_optim.state, x)
        for key in ("weight", "bias"):
            self.assertTrue(mx.allclose(pure_params[key], reference[key]))

        # Impure compiled step: model and optimizer state are captured
        # through the compile inputs/outputs.
        model.update(initial_params)
        impure_optim = fresh_sgd()
        captured_state = [model.state, impure_optim.state]

        @partial(mx.compile, inputs=captured_state, outputs=captured_state)
        def impure_step(x):
            _, grad = nn.value_and_grad(model, model_loss)(model, x)
            impure_optim.update(model, grad)

        impure_step(x)
        impure_params = model.parameters()
        for key in ("weight", "bias"):
            self.assertTrue(mx.allclose(impure_params[key], reference[key]))
 | 
						|
 | 
						|
    def test_update_lr_compiled(self):
        params = {"w": mx.ones((5, 5))}
        grads = tree_map(mx.ones_like, params)
        sgd = opt.SGD(-1.0)

        @partial(mx.compile, inputs=sgd.state)
        def compiled_update(grads):
            return sgd.apply_gradients(grads, params)

        # With a learning rate of -1 the all-ones weights become 2.
        first = compiled_update(grads)
        self.assertTrue(mx.allclose(first["w"], mx.full((5, 5), 2.0)))

        # Changing the learning rate afterwards must be picked up by the
        # already-compiled function: the weights now become 3.
        sgd.learning_rate = -2.0
        second = compiled_update(grads)
        self.assertTrue(mx.allclose(second["w"], mx.full((5, 5), 3.0)))
 | 
						|
 | 
						|
 | 
						|
class TestSchedulers(mlx_tests.MLXTestCase):
 | 
						|
    def test_decay_lr(self):
        for optimizer_cls in optimizers_dict.values():
            schedule = opt.step_decay(1e-1, 0.9, 1)
            optimizer = optimizer_cls(learning_rate=schedule)

            params = {"w": mx.ones((5, 5))}
            grads = tree_map(mx.ones_like, params)

            # Right after update number k + 1 the learning rate is expected
            # to be 0.1 * 0.9**k.
            for step_index in range(10):
                optimizer.apply_gradients(grads, params)
                self.assertAlmostEqual(
                    optimizer.learning_rate,
                    0.1 * (0.9**step_index),
                    delta=1e-7,
                )
 | 
						|
 | 
						|
    def test_step_decay(self):
        # At step 2500 with a step size of 1000, two decays are expected.
        schedule = opt.step_decay(1e-1, 0.9, 1000)
        self.assertAlmostEqual(schedule(2500), 0.1 * (0.9**2), delta=1e-7)
 | 
						|
 | 
						|
    def test_exponential_decay(self):
        # Ten steps at a decay rate of 0.99 starting from 0.1.
        schedule = opt.exponential_decay(1e-1, 0.99)
        self.assertAlmostEqual(schedule(10), 0.1 * (0.99**10), delta=1e-7)
 | 
						|
 | 
						|
    def test_cosine_decay(self):
        # Without an end value: check step 4 of 10 against the closed form.
        schedule = opt.cosine_decay(0.1, 10)
        closed_form = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
        self.assertAlmostEqual(schedule(4), closed_form, delta=1e-7)

        # With an end value: still above it just before the last decay step,
        # and exactly equal to it once past the decay steps.
        end_lr = 0.05
        schedule_with_end = opt.cosine_decay(0.1, 10, 0.05)
        self.assertGreater(schedule_with_end(9), end_lr)
        self.assertEqual(schedule_with_end(20), end_lr)
 | 
						|
 | 
						|
    def test_schedule_joiner(self):
        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]

        # Three schedules with three boundaries must be rejected.
        with self.assertRaises(ValueError):
            opt.schedulers.join_schedules(schedules, [2, 3, 4])

        # With two boundaries, each schedule covers its own range of steps.
        joined = opt.schedulers.join_schedules(schedules, [2, 4])
        expectations = ((0, 3), (1, 3), (2, 4), (3, 4), (5, 5), (7, 5))
        for step, expected in expectations:
            self.assertEqual(joined(step).item(), expected)
 | 
						|
 | 
						|
    def test_linear_warmup_with_cosine_decay(self):
        # The checks below use tolerances scaled to the learning rate
        # (~1e-5). A delta of 1e-1 would accept any plausible value and so
        # would not test anything.
        peak_lr = 1e-5
        warmup_steps = 100
        decay_steps = 100
        boundary = 101

        warmup_schedule = opt.schedulers.linear_schedule(0.0, peak_lr, warmup_steps)
        cosine_schedule = opt.schedulers.cosine_decay(peak_lr, decay_steps)
        cos_with_warmup = opt.schedulers.join_schedules(
            [warmup_schedule, cosine_schedule], [boundary]
        )
        self.assertEqual(cos_with_warmup(0), 0.0)
        # At the boundary the cosine schedule starts from its initial value.
        self.assertAlmostEqual(cos_with_warmup(boundary), peak_lr, delta=1e-8)

        optimizer = opt.Adam(learning_rate=cos_with_warmup)
        for _ in range(100):
            optimizer.update({}, {})
        # After k updates the learning rate is the schedule evaluated at step
        # k - 1 (the same relation test_compile_with_schedule checks), so
        # step 99 is still inside the linear warmup.
        expected_lr = peak_lr * 99 / warmup_steps
        self.assertAlmostEqual(
            optimizer.learning_rate.item(), expected_lr, delta=1e-8
        )

        for _ in range(100):
            optimizer.update({}, {})
        # Step 199 is past the boundary. The joined schedule is expected to
        # evaluate the cosine schedule relative to the boundary, i.e. at
        # 199 - 101 = 98 of its 100 decay steps.
        cosine_step = 199 - boundary
        expected_lr = (
            peak_lr * 0.5 * (1.0 + math.cos(math.pi * cosine_step / decay_steps))
        )
        self.assertAlmostEqual(
            optimizer.learning_rate.item(), expected_lr, delta=1e-9
        )
 | 
						|
 | 
						|
    def test_compile_with_schedule(self):
        schedule = opt.exponential_decay(1e-1, 0.9)
        sgd = opt.SGD(learning_rate=schedule)

        @partial(mx.compile, inputs=sgd.state, outputs=sgd.state)
        def compiled_update():
            sgd.update({}, {})

        # After each compiled update the optimizer's learning rate must match
        # the schedule evaluated at that step index.
        for step in range(5):
            compiled_update()
            self.assertAlmostEqual(schedule(step), sgd.learning_rate.item())
 | 
						|
 | 
						|
    def test_clip_grad_norm(self):
        # Small gradients with a large max_norm: nothing should be clipped.
        small_grads = {
            "first": [mx.array([0.1, 0.2]), mx.array([0.1])],
            "second": mx.array([0.3]),
        }
        unclipped, _ = opt.clip_grad_norm(small_grads, 10.0)
        self.assertTrue(
            tree_equal(mx.array_equal, small_grads, unclipped),
            "Gradients should not be modified when clipping is not necessary.",
        )

        # Large gradients with a small max_norm: clipping should kick in.
        large_grads = {
            "first": [mx.array([10, 20]), mx.array([10])],
            "second": mx.array([30]),
        }
        max_norm = 1.0
        clipped, total_norm = opt.clip_grad_norm(large_grads, max_norm)

        # Recompute the global norm from the clipped leaves only.
        squared_sum = sum(
            mx.square(leaf).sum() for _, leaf in tree_flatten(clipped)
        )
        self.assertAlmostEqual(
            mx.sqrt(squared_sum).item(),
            max_norm,
            places=6,
            msg="Clipped gradients norm should be close to the specified max_norm.",
        )

        # Each leaf should be the original scaled by max_norm / total_norm.
        scale = max_norm / total_norm
        expected = tree_map(lambda g: g * scale, large_grads)

        def close_enough(want, got):
            return mx.allclose(want, got, atol=1e-6)

        self.assertTrue(
            tree_equal(close_enough, expected, clipped),
            "Gradients were not scaled correctly during clipping.",
        )
 | 
						|
 | 
						|
    def test_init_from_state(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = nn.Linear(2, 2)
                self.drop = nn.Dropout(p=0.5)
                self.l2 = nn.Linear(2, 2)
                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]

        model = Model()
        source_optimizer = opt.Adam(learning_rate=3e-4)
        source_optimizer.init(model.trainable_parameters())

        # Flatten the state, as would be done to serialize it.
        flat_state = tree_flatten(source_optimizer.state)

        # Load the flattened state into a brand new optimizer.
        restored_optimizer = opt.Adam(learning_rate=3e-4)
        restored_optimizer.state = tree_unflatten(flat_state)

        # An update with the restored state should run without errors; the
        # parameters themselves stand in for the gradients.
        stand_in_grads = model.trainable_parameters()
        restored_optimizer.update(model, stand_in_grads)
 | 
						|
 | 
						|
    def test_multi_optimizer(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = nn.Linear(2, 2)
                self.drop = nn.Dropout(p=0.5)
                self.l2 = nn.Linear(2, 2)
                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]

        model = Model()
        # Two optimizers with a single filter that selects arrays of rank > 1.
        adam = opt.Adam(learning_rate=0.001)
        sgd = opt.SGD(learning_rate=0.1)
        optimizer = opt.MultiOptimizer(
            [adam, sgd],
            [lambda name, weight: weight.ndim > 1],
        )
        optimizer.init(model.trainable_parameters())

        # One state entry per wrapped optimizer.
        per_optimizer_states = optimizer.state["states"]
        self.assertEqual(len(per_optimizer_states), 2)

        adam_states = tree_flatten(per_optimizer_states[0])
        sgd_states = tree_flatten(per_optimizer_states[1])
        # NOTE(review): the "- 2" presumably discounts two non-parameter
        # entries in each state, and the "* 2" presumably reflects Adam
        # keeping two buffers per parameter -- confirm against the optimizers.
        self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)

        # No "bias" key in the first optimizer's state and no "weight" key in
        # the second's.
        self.assertFalse(any("bias" in key for key, _ in adam_states))
        self.assertFalse(any("weight" in key for key, _ in sgd_states))
 | 
						|
 | 
						|
 | 
						|
# When executed as a script, hand control to the test runner from mlx_tests.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
 |