FLUX: save final adapters file

This commit is contained in:
madroid 2024-10-15 13:49:44 +08:00
parent 130cdae48e
commit 7df99c6b71
4 changed files with 11 additions and 5 deletions

View File

@ -185,7 +185,7 @@ The adapters are saved in `mlx_output` and can be used directly by the
```shell ```shell
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \ python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
--adapter mlx_output/0001200_adapters.safetensors \ --adapter mlx_output/final_adapter.safetensors \
--fuse-adapter \ --fuse-adapter \
--no-t5-padding \ --no-t5-padding \
'A photo of an sks dog lying on the sand at a beach in Greece' 'A photo of an sks dog lying on the sand at a beach in Greece'

View File

@ -141,10 +141,10 @@ def generate_progress_images(iteration, flux, args):
im.save(out_file) im.save(out_file)
def save_adapters(iteration, flux, args): def save_adapters(adapter_name, flux, args):
out_dir = Path(args.output_dir) out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_adapters.safetensors" out_file = out_dir / adapter_name
print(f"Saving {str(out_file)}") print(f"Saving {str(out_file)}")
mx.save_safetensors( mx.save_safetensors(
@ -375,8 +375,13 @@ if __name__ == "__main__":
generate_progress_images(i + 1, flux, args) generate_progress_images(i + 1, flux, args)
if (i + 1) % args.checkpoint_every == 0: if (i + 1) % args.checkpoint_every == 0:
save_adapters(i + 1, flux, args) file_name = f"{i + 1:07d}_adapters.safetensors"
save_adapters(file_name, flux, args)
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0:
losses = [] losses = []
tic = time.time() tic = time.time()
final_adapter = "final_adapter.safetensors"
save_adapters(final_adapter, flux, args)
print(f"Training successful. Saved final weights to {args.adapter_file}.")

View File

@ -18,6 +18,7 @@ from .utils import (
load_flow_model, load_flow_model,
load_t5, load_t5,
load_t5_tokenizer, load_t5_tokenizer,
save_config,
) )