mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Nits
This commit is contained in:
parent
b899e81589
commit
ebf314bdcc
@ -282,13 +282,11 @@ if __name__ == "__main__":
|
||||
generate_progress_images(i + 1, flux, args)
|
||||
|
||||
if (i + 1) % args.checkpoint_every == 0:
|
||||
file_name = f"{i + 1:07d}_adapters.safetensors"
|
||||
save_adapters(file_name, flux, args)
|
||||
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
losses = []
|
||||
tic = time.time()
|
||||
|
||||
final_adapter = "final_adapter.safetensors"
|
||||
save_adapters(final_adapter, flux, args)
|
||||
save_adapters("final_adapters.safetensors", flux, args)
|
||||
print(f"Training successful. Saved final weights to {args.adapter_file}.")
|
||||
|
@ -222,12 +222,9 @@ def save_config(
|
||||
config (dict): The model configuration.
|
||||
config_path (Union[str, Path]): Model configuration file path.
|
||||
"""
|
||||
# Clean unused keys
|
||||
config.pop("_name_or_path", None)
|
||||
|
||||
# sort the config for better readability
|
||||
# Sort the config for better readability
|
||||
config = dict(sorted(config.items()))
|
||||
|
||||
# write the updated config to the config_path (if provided)
|
||||
# Write the config to the provided file
|
||||
with open(config_path, "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
Loading…
Reference in New Issue
Block a user