mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
FLUX: save train config (#1049)
This commit is contained in:
@@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from PIL import Image
|
||||
|
||||
from flux import FluxPipeline, Trainer, load_dataset
|
||||
from flux import FluxPipeline, Trainer, load_dataset, save_config
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
@@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args):
|
||||
im.save(out_file)
|
||||
|
||||
|
||||
def save_adapters(iteration, flux, args):
|
||||
def save_adapters(adapter_name, flux, args):
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
|
||||
out_file = out_dir / adapter_name
|
||||
print(f"Saving {str(out_file)}")
|
||||
|
||||
mx.save_safetensors(
|
||||
@@ -157,6 +157,10 @@ if __name__ == "__main__":
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
output_path = Path(args.output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
save_config(vars(args), output_path / "adapter_config.json")
|
||||
|
||||
# Load the model and set it up for LoRA training. We use the same random
|
||||
# state when creating the LoRA layers so all workers will have the same
|
||||
# initial weights.
|
||||
@@ -278,8 +282,11 @@ if __name__ == "__main__":
|
||||
generate_progress_images(i + 1, flux, args)
|
||||
|
||||
if (i + 1) % args.checkpoint_every == 0:
|
||||
save_adapters(i + 1, flux, args)
|
||||
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
losses = []
|
||||
tic = time.time()
|
||||
|
||||
save_adapters("final_adapters.safetensors", flux, args)
|
||||
print(f"Training successful. Saved final weights to {args.adapter_file}.")
|
||||
|
||||
Reference in New Issue
Block a user