diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 9bd572e3..f0d8e0a4 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -189,7 +189,6 @@ def train(
 
     state = [model.state, optimizer.state]
 
-    @partial(mx.compile, inputs=state, outputs=state)
     def step(batch):
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
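
For context, the removed line is MLX's compile-with-state pattern: wrapping the training step with `mx.compile(fn, inputs=state, outputs=state)` tells the compiler which arrays the function reads and mutates implicitly (the model parameters and optimizer state), and dropping the decorator means `step` now runs eagerly. Below is a minimal, self-contained sketch of that pattern as documented by MLX; the toy model, loss function, and data are hypothetical stand-ins and are not taken from trainer.py.

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Toy model and optimizer (illustrative only, not the PR's code).
model = nn.Linear(4, 4)
optimizer = optim.SGD(learning_rate=1e-2)


def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)


loss_value_and_grad = nn.value_and_grad(model, loss_fn)

# Arrays that `step` reads and writes implicitly: parameters and optimizer state.
state = [model.state, optimizer.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    # Forward and backward pass
    loss, grad = loss_value_and_grad(model, x, y)
    # Model update (mutates model.state and optimizer.state in place)
    optimizer.update(model, grad)
    return loss


x = mx.random.normal((8, 4))
y = mx.random.normal((8, 4))
loss = step(x, y)
mx.eval(state)  # force evaluation of the lazily computed update
```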