diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index dd2a8b67..38619d95 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -11,7 +11,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients -from mlx.utils import tree_flatten, tree_map +from mlx.utils import tree_flatten def grad_checkpoint(layer):