mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
FLUX: save final adapters file
This commit is contained in:
parent
130cdae48e
commit
7df99c6b71
@ -185,7 +185,7 @@ The adapters are saved in `mlx_output` and can be used directly by the
|
||||
|
||||
```shell
|
||||
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
|
||||
--adapter mlx_output/0001200_adapters.safetensors \
|
||||
--adapter mlx_output/final_adapter.safetensors \
|
||||
--fuse-adapter \
|
||||
--no-t5-padding \
|
||||
'A photo of an sks dog lying on the sand at a beach in Greece'
|
||||
|
@ -141,10 +141,10 @@ def generate_progress_images(iteration, flux, args):
|
||||
im.save(out_file)
|
||||
|
||||
|
||||
def save_adapters(iteration, flux, args):
|
||||
def save_adapters(adapter_name, flux, args):
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
|
||||
out_file = out_dir / adapter_name
|
||||
print(f"Saving {str(out_file)}")
|
||||
|
||||
mx.save_safetensors(
|
||||
@ -375,8 +375,13 @@ if __name__ == "__main__":
|
||||
generate_progress_images(i + 1, flux, args)
|
||||
|
||||
if (i + 1) % args.checkpoint_every == 0:
|
||||
save_adapters(i + 1, flux, args)
|
||||
file_name = f"{i + 1:07d}_adapters.safetensors"
|
||||
save_adapters(file_name, flux, args)
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
losses = []
|
||||
tic = time.time()
|
||||
|
||||
final_adapter = "final_adapter.safetensors"
|
||||
save_adapters(final_adapter, flux, args)
|
||||
print(f"Training successful. Saved final weights to {args.adapter_file}.")
|
||||
|
@ -18,6 +18,7 @@ from .utils import (
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
save_config,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user