try again with checkpointed classes

This commit is contained in:
Awni Hannun 2024-03-06 10:36:14 -08:00
parent a5827d0384
commit 0dbe80a024

View File

@ -48,25 +48,35 @@ def checkpoint(module: Module):
checkpointing.
Returns:
The module that saves the inputs and outputs during the forward pass
A new module that saves the inputs and outputs during the forward pass
and recomputes all intermediate states during the backward pass.
"""
fn = module.__call__
t = type(module)
cp_name = f"__checkpointed_{id(t)}__"
cp_class = globals().get(cp_name, None)
if cp_class is None:
def init(self):
pass
def call(self, *args, **kwargs):
return self.__checkpointed_call__(
self.trainable_parameters(), *args, **kwargs
)
cp_class = type(t.__name__, (t,), {})
cp_class.__init__ = init
cp_class.__call__ = call
globals()[cp_name] = cp_class
cp_module = cp_class()
cp_module.__dict__.update(module.__dict__)
super(Module, cp_module).update(module.state)
def inner_fn(params, *args, **kwargs):
module.update(params)
return fn(*args, **kwargs)
return module(*args, **kwargs)
checkpointed_fn = mx.checkpoint(inner_fn)
@wraps(fn)
def wrapped_checkpointed_fn(*args, **kwargs):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
class _(type(module)):
def __call__(self, *args, **kwargs):
return wrapped_checkpointed_fn(*args, **kwargs)
module.__class__ = _
return module
cp_module.__checkpointed_call__ = mx.checkpoint(inner_fn)
return cp_module