diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 64e26af8..35f18d29 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -1,20 +1,14 @@ # Copyright © 2024 Apple Inc. -import glob -import shutil import time from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten -from transformers import PreTrainedTokenizer - -from .datasets import CompletionsDataset def grad_checkpoint(layer):