reduction moved to CPU in case of distributed training (#1200)

This commit is contained in:
Ivan Fioravanti
2025-01-15 02:20:42 +01:00
committed by GitHub
parent c117af83b8
commit 6ae6c72c2e
2 changed files with 5 additions and 5 deletions

View File

@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
@contextmanager
def swapped_with_identity(obj, func):
old_func = getattr(obj, func)
setattr(obj, func, lambda x: x)
setattr(obj, func, lambda x, **kwargs: x)
yield
setattr(obj, func, old_func)