From 0dbe80a0240619ffa1de5013134efc4bb1c32930 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 6 Mar 2024 10:36:14 -0800 Subject: [PATCH] try again with checkpointed classes --- python/mlx/nn/utils.py | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py index ebd340599..ff8c34ea6 100644 --- a/python/mlx/nn/utils.py +++ b/python/mlx/nn/utils.py @@ -48,25 +48,35 @@ def checkpoint(module: Module): checkpointing. Returns: - The module that saves the inputs and outputs during the forward pass + A new module that saves the inputs and outputs during the forward pass and recomputes all intermediate states during the backward pass. """ - fn = module.__call__ + t = type(module) + cp_name = f"__checkpointed_{id(t)}__" + cp_class = globals().get(cp_name, None) + if cp_class is None: + + def init(self): + pass + + def call(self, *args, **kwargs): + return self.__checkpointed_call__( + self.trainable_parameters(), *args, **kwargs + ) + + cp_class = type(t.__name__, (t,), {}) + cp_class.__init__ = init + cp_class.__call__ = call + globals()[cp_name] = cp_class + + cp_module = cp_class() + cp_module.__dict__.update(module.__dict__) + super(Module, cp_module).update(module.state) def inner_fn(params, *args, **kwargs): module.update(params) - return fn(*args, **kwargs) + return module(*args, **kwargs) - checkpointed_fn = mx.checkpoint(inner_fn) - - @wraps(fn) - def wrapped_checkpointed_fn(*args, **kwargs): - return checkpointed_fn(module.trainable_parameters(), *args, **kwargs) - - class _(type(module)): - def __call__(self, *args, **kwargs): - return wrapped_checkpointed_fn(*args, **kwargs) - - module.__class__ = _ - return module + cp_module.__checkpointed_call__ = mx.checkpoint(inner_fn) + return cp_module