mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-08 18:06:37 +08:00
[Whisper] Add HF Hub upload option. (#254)
* Add HF Hub upload option. * up. * Add missing requirements.
This commit is contained in:
parent
6e5b0de4d3
commit
d4c3a9cb54
@ -206,6 +206,44 @@ def torch_to_mlx(
|
||||
return mlx_model
|
||||
|
||||
|
||||
def upload_to_hub(path: str, name: str, torch_name_or_path: str):
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, ModelCard, logging
|
||||
|
||||
repo_id = f"mlx-community/{name}"
|
||||
text = f"""
|
||||
---
|
||||
library_name: mlx
|
||||
---
|
||||
|
||||
# {name}
|
||||
This model was converted to MLX format from [`{torch_name_or_path}`]().
|
||||
|
||||
## Use with mlx
|
||||
```bash
|
||||
git clone https://github.com/ml-explore/mlx-examples.git
|
||||
cd mlx-examples/whisper/
|
||||
pip install -r requirements.txt
|
||||
|
||||
>> import whisper
|
||||
>> whisper.transcribe("FILE_NAME")
|
||||
```
|
||||
"""
|
||||
card = ModelCard(text)
|
||||
card.save(os.path.join(path, "README.md"))
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=path,
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
|
||||
def quantize(weights, config, args):
|
||||
quantized_config = copy.deepcopy(config)
|
||||
|
||||
@ -238,7 +276,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_models/tiny",
|
||||
default="mlx_models",
|
||||
help="The path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -265,6 +303,13 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-name",
|
||||
help="The name of model to upload to Hugging Face MLX Community",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert (
|
||||
@ -292,3 +337,6 @@ if __name__ == "__main__":
|
||||
with open(str(mlx_path / "config.json"), "w") as f:
|
||||
config["model_type"] = "whisper"
|
||||
json.dump(config, f, indent=4)
|
||||
|
||||
if args.upload_name is not None:
|
||||
upload_to_hub(mlx_path, args.upload_name, args.torch_name_or_path)
|
||||
|
@ -5,3 +5,5 @@ torch
|
||||
tqdm
|
||||
more-itertools
|
||||
tiktoken==0.3.3
|
||||
huggingface_hub
|
||||
scipy
|
||||
|
Loading…
Reference in New Issue
Block a user