try again with checkpointed classes

This commit is contained in:
Awni Hannun 2024-03-06 10:36:14 -08:00
parent a5827d0384
commit 0dbe80a024

View File

@ -48,25 +48,35 @@ def checkpoint(module: Module):
checkpointing. checkpointing.
Returns: Returns:
The module that saves the inputs and outputs during the forward pass A new module that saves the inputs and outputs during the forward pass
and recomputes all intermediate states during the backward pass. and recomputes all intermediate states during the backward pass.
""" """
fn = module.__call__ t = type(module)
cp_name = f"__checkpointed_{id(t)}__"
cp_class = globals().get(cp_name, None)
if cp_class is None:
def init(self):
pass
def call(self, *args, **kwargs):
return self.__checkpointed_call__(
self.trainable_parameters(), *args, **kwargs
)
cp_class = type(t.__name__, (t,), {})
cp_class.__init__ = init
cp_class.__call__ = call
globals()[cp_name] = cp_class
cp_module = cp_class()
cp_module.__dict__.update(module.__dict__)
super(Module, cp_module).update(module.state)
def inner_fn(params, *args, **kwargs): def inner_fn(params, *args, **kwargs):
module.update(params) module.update(params)
return fn(*args, **kwargs) return module(*args, **kwargs)
checkpointed_fn = mx.checkpoint(inner_fn) cp_module.__checkpointed_call__ = mx.checkpoint(inner_fn)
return cp_module
@wraps(fn)
def wrapped_checkpointed_fn(*args, **kwargs):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
class _(type(module)):
def __call__(self, *args, **kwargs):
return wrapped_checkpointed_fn(*args, **kwargs)
module.__class__ = _
return module