mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 06:21:12 +08:00
try again with checkpointed classes
This commit is contained in:
parent
a5827d0384
commit
0dbe80a024
@ -48,25 +48,35 @@ def checkpoint(module: Module):
|
||||
checkpointing.
|
||||
|
||||
Returns:
|
||||
The module that saves the inputs and outputs during the forward pass
|
||||
A new module that saves the inputs and outputs during the forward pass
|
||||
and recomputes all intermediate states during the backward pass.
|
||||
"""
|
||||
|
||||
fn = module.__call__
|
||||
t = type(module)
|
||||
cp_name = f"__checkpointed_{id(t)}__"
|
||||
cp_class = globals().get(cp_name, None)
|
||||
if cp_class is None:
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
return self.__checkpointed_call__(
|
||||
self.trainable_parameters(), *args, **kwargs
|
||||
)
|
||||
|
||||
cp_class = type(t.__name__, (t,), {})
|
||||
cp_class.__init__ = init
|
||||
cp_class.__call__ = call
|
||||
globals()[cp_name] = cp_class
|
||||
|
||||
cp_module = cp_class()
|
||||
cp_module.__dict__.update(module.__dict__)
|
||||
super(Module, cp_module).update(module.state)
|
||||
|
||||
def inner_fn(params, *args, **kwargs):
|
||||
module.update(params)
|
||||
return fn(*args, **kwargs)
|
||||
return module(*args, **kwargs)
|
||||
|
||||
checkpointed_fn = mx.checkpoint(inner_fn)
|
||||
|
||||
@wraps(fn)
|
||||
def wrapped_checkpointed_fn(*args, **kwargs):
|
||||
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
|
||||
|
||||
class _(type(module)):
|
||||
def __call__(self, *args, **kwargs):
|
||||
return wrapped_checkpointed_fn(*args, **kwargs)
|
||||
|
||||
module.__class__ = _
|
||||
return module
|
||||
cp_module.__checkpointed_call__ = mx.checkpoint(inner_fn)
|
||||
return cp_module
|
||||
|
Loading…
Reference in New Issue
Block a user