diff --git a/flux/README.md b/flux/README.md index 0496c71b..3d836609 100644 --- a/flux/README.md +++ b/flux/README.md @@ -185,7 +185,7 @@ The adapters are saved in `mlx_output` and can be used directly by the ```shell python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \ - --adapter mlx_output/0001200_adapters.safetensors \ + --adapter mlx_output/final_adapter.safetensors \ --fuse-adapter \ --no-t5-padding \ 'A photo of an sks dog lying on the sand at a beach in Greece' diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 42250a3f..afc02305 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -141,10 +141,10 @@ def generate_progress_images(iteration, flux, args): im.save(out_file) -def save_adapters(iteration, flux, args): +def save_adapters(adapter_name, flux, args): out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) - out_file = out_dir / f"{iteration:07d}_adapters.safetensors" + out_file = out_dir / adapter_name print(f"Saving {str(out_file)}") mx.save_safetensors( @@ -375,8 +375,13 @@ if __name__ == "__main__": generate_progress_images(i + 1, flux, args) if (i + 1) % args.checkpoint_every == 0: - save_adapters(i + 1, flux, args) + file_name = f"{i + 1:07d}_adapters.safetensors" + save_adapters(file_name, flux, args) if (i + 1) % 10 == 0: losses = [] tic = time.time() + + final_adapter = "final_adapter.safetensors" + save_adapters(final_adapter, flux, args) + print(f"Training successful. Saved final weights to {args.adapter_file}.") diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index 8d39d605..aca1fad2 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -18,6 +18,7 @@ from .utils import ( load_flow_model, load_t5, load_t5_tokenizer, + save_config, ) diff --git a/flux/flux/utils.py b/flux/flux/utils.py index 43239f44..ffd99176 100644 --- a/flux/flux/utils.py +++ b/flux/flux/utils.py @@ -230,4 +230,4 @@ def save_config( # write the updated config to the config_path (if provided) with open(config_path, "w") as fid: - json.dump(config, fid, indent=4) \ No newline at end of file + json.dump(config, fid, indent=4)