mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	try again with checkpointed classes
This commit is contained in:
		| @@ -48,25 +48,35 @@ def checkpoint(module: Module): | ||||
|           checkpointing. | ||||
|  | ||||
|     Returns: | ||||
|         The module that saves the inputs and outputs during the forward pass | ||||
|         A new module that saves the inputs and outputs during the forward pass | ||||
|         and recomputes all intermediate states during the backward pass. | ||||
|     """ | ||||
|  | ||||
|     fn = module.__call__ | ||||
|     t = type(module) | ||||
|     cp_name = f"__checkpointed_{id(t)}__" | ||||
|     cp_class = globals().get(cp_name, None) | ||||
|     if cp_class is None: | ||||
|  | ||||
|         def init(self): | ||||
|             pass | ||||
|  | ||||
|         def call(self, *args, **kwargs): | ||||
|             return self.__checkpointed_call__( | ||||
|                 self.trainable_parameters(), *args, **kwargs | ||||
|             ) | ||||
|  | ||||
|         cp_class = type(t.__name__, (t,), {}) | ||||
|         cp_class.__init__ = init | ||||
|         cp_class.__call__ = call | ||||
|         globals()[cp_name] = cp_class | ||||
|  | ||||
|     cp_module = cp_class() | ||||
|     cp_module.__dict__.update(module.__dict__) | ||||
|     super(Module, cp_module).update(module.state) | ||||
|  | ||||
|     def inner_fn(params, *args, **kwargs): | ||||
|         module.update(params) | ||||
|         return fn(*args, **kwargs) | ||||
|         return module(*args, **kwargs) | ||||
|  | ||||
|     checkpointed_fn = mx.checkpoint(inner_fn) | ||||
|  | ||||
|     @wraps(fn) | ||||
|     def wrapped_checkpointed_fn(*args, **kwargs): | ||||
|         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs) | ||||
|  | ||||
|     class _(type(module)): | ||||
|         def __call__(self, *args, **kwargs): | ||||
|             return wrapped_checkpointed_fn(*args, **kwargs) | ||||
|  | ||||
|     module.__class__ = _ | ||||
|     return module | ||||
|     cp_module.__checkpointed_call__ = mx.checkpoint(inner_fn) | ||||
|     return cp_module | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun