This commit is contained in:
Goekdeniz-Guelmez 2025-01-19 02:03:50 +01:00
parent 9ede9db19b
commit 424cb854e9

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import numpy as np
from mlx.utils import tree_flatten
from mlx.nn.utils import average_gradients
from .trainer import TrainingArgs, grad_checkpoint
from .trainer import TrainingArgs, grad_checkpoint, TrainingCallback
@dataclass