mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
reduction moved to CPU in case of distributed training (#1200)
This commit is contained in:
@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
|
||||
@contextmanager
|
||||
def swapped_with_identity(obj, func):
|
||||
old_func = getattr(obj, func)
|
||||
setattr(obj, func, lambda x: x)
|
||||
setattr(obj, func, lambda x, **kwargs: x)
|
||||
yield
|
||||
setattr(obj, func, old_func)
|
||||
|
||||
|
Reference in New Issue
Block a user