mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add push to hub
This commit is contained in:
parent
1e0cda68c6
commit
2dd903b0bf
@ -5,6 +5,7 @@ import json
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -15,6 +16,9 @@ from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from huggingface_hub import HfApi, interpreter_login
|
||||
from huggingface_hub.utils import HfFolder
|
||||
|
||||
from flux import FluxPipeline
|
||||
|
||||
|
||||
@ -156,6 +160,108 @@ def save_adapters(iteration, flux, args):
|
||||
},
|
||||
)
|
||||
|
||||
def push_to_hub(args):
|
||||
if args.hf_token is None:
|
||||
interpreter_login(new_session=False, write_permission=True)
|
||||
else:
|
||||
HfFolder.save_token(args.hf_token)
|
||||
|
||||
repo_id = args.hf_repo_id or f"{HfFolder.get_token_username()}/{args.output_dir}"
|
||||
|
||||
readme_content = generate_readme(args, repo_id)
|
||||
readme_path = os.path.join(args.output_dir, "README.md")
|
||||
with open(readme_path, "w", encoding="utf-8") as f:
|
||||
f.write(readme_content)
|
||||
|
||||
api = HfApi()
|
||||
|
||||
api.create_repo(
|
||||
repo_id,
|
||||
private=args.hf_private,
|
||||
exist_ok=True
|
||||
)
|
||||
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
ignore_patterns=["*.yaml", "*.pt"],
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
def generate_readme(args, repo_id):
|
||||
import yaml
|
||||
import re
|
||||
base_model = f"flux-{args.model}"
|
||||
tags = [
|
||||
"text-to-image",
|
||||
"flux",
|
||||
"lora",
|
||||
"diffusers",
|
||||
"template:sd-lora",
|
||||
"mlx",
|
||||
"mlx-trainer"
|
||||
]
|
||||
|
||||
widgets = []
|
||||
sample_image_paths = []
|
||||
# Look for progress images directly in the output directory
|
||||
for filename in os.listdir(args.output_dir):
|
||||
match = re.search(r"(\d+)_progress\.png$", filename)
|
||||
if match:
|
||||
iteration = int(match.group(1))
|
||||
sample_image_paths.append((iteration, filename))
|
||||
|
||||
sample_image_paths.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
if sample_image_paths:
|
||||
widgets.append(
|
||||
{
|
||||
"text": args.progress_prompt,
|
||||
"output": {
|
||||
"url": sample_image_paths[0][1]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
readme_content = f"""---
|
||||
tags:
|
||||
{yaml.dump(tags, indent=4).strip()}
|
||||
{"widget:" if sample_image_paths else ""}
|
||||
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
|
||||
base_model: {base_model}
|
||||
license: other
|
||||
---
|
||||
|
||||
# {os.path.basename(args.output_dir)}
|
||||
Model trained with the MLX Flux Dreambooth script
|
||||
|
||||
<Gallery />
|
||||
|
||||
## Use it with [MLX](https://github.com/ml-explore/mlx-examples)
|
||||
```py
|
||||
from flux import FluxPipeline
|
||||
import mlx.core as mx
|
||||
flux = FluxPipeline("flux-{args.model}")
|
||||
flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks})
|
||||
flux.flow.load_weights("{repo_id}")
|
||||
image = flux.generate_images("{args.progress_prompt}", n_images=1, num_steps={args.progress_steps})
|
||||
image.save("my_image.png")
|
||||
```
|
||||
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/{args.model}', torch_dtype=torch.bfloat16).to('cuda')
|
||||
pipeline.load_lora_weights('{repo_id}')
|
||||
image = pipeline({args.progress_prompt}').images[0]
|
||||
image.save("my_image.png")
|
||||
```
|
||||
|
||||
For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux).
|
||||
"""
|
||||
return readme_content
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -245,7 +351,28 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Push the model to Hugging Face Hub after training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hugging Face token for pushing to Hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_repo_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hugging Face repository ID for pushing to Hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_private",
|
||||
action="store_true",
|
||||
help="Make the Hugging Face repository private",
|
||||
)
|
||||
parser.add_argument("dataset")
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -376,3 +503,6 @@ if __name__ == "__main__":
|
||||
if (i + 1) % 10 == 0:
|
||||
losses = []
|
||||
tic = time.time()
|
||||
|
||||
if args.push_to_hub:
|
||||
push_to_hub(args)
|
||||
|
Loading…
Reference in New Issue
Block a user