mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
try again with checkpointed classes
This commit is contained in:
parent
a5827d0384
commit
0dbe80a024
@ -48,25 +48,35 @@ def checkpoint(module: Module):
|
|||||||
checkpointing.
|
checkpointing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The module that saves the inputs and outputs during the forward pass
|
A new module that saves the inputs and outputs during the forward pass
|
||||||
and recomputes all intermediate states during the backward pass.
|
and recomputes all intermediate states during the backward pass.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
fn = module.__call__
|
t = type(module)
|
||||||
|
cp_name = f"__checkpointed_{id(t)}__"
|
||||||
|
cp_class = globals().get(cp_name, None)
|
||||||
|
if cp_class is None:
|
||||||
|
|
||||||
|
def init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def call(self, *args, **kwargs):
|
||||||
|
return self.__checkpointed_call__(
|
||||||
|
self.trainable_parameters(), *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
cp_class = type(t.__name__, (t,), {})
|
||||||
|
cp_class.__init__ = init
|
||||||
|
cp_class.__call__ = call
|
||||||
|
globals()[cp_name] = cp_class
|
||||||
|
|
||||||
|
cp_module = cp_class()
|
||||||
|
cp_module.__dict__.update(module.__dict__)
|
||||||
|
super(Module, cp_module).update(module.state)
|
||||||
|
|
||||||
def inner_fn(params, *args, **kwargs):
|
def inner_fn(params, *args, **kwargs):
|
||||||
module.update(params)
|
module.update(params)
|
||||||
return fn(*args, **kwargs)
|
return module(*args, **kwargs)
|
||||||
|
|
||||||
checkpointed_fn = mx.checkpoint(inner_fn)
|
cp_module.__checkpointed_call__ = mx.checkpoint(inner_fn)
|
||||||
|
return cp_module
|
||||||
@wraps(fn)
|
|
||||||
def wrapped_checkpointed_fn(*args, **kwargs):
|
|
||||||
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
|
|
||||||
|
|
||||||
class _(type(module)):
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
return wrapped_checkpointed_fn(*args, **kwargs)
|
|
||||||
|
|
||||||
module.__class__ = _
|
|
||||||
return module
|
|
||||||
|
Loading…
Reference in New Issue
Block a user