[Whisper] Add HF Hub upload option. (#254)

* Add HF Hub upload option.

* up.

* Add missing requirements.
This commit is contained in:
Vaibhav Srivastav 2024-01-08 19:48:24 +05:30 committed by GitHub
parent 6e5b0de4d3
commit d4c3a9cb54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 1 deletions

View File

@ -206,6 +206,44 @@ def torch_to_mlx(
return mlx_model
def upload_to_hub(path: str, name: str, torch_name_or_path: str):
    """Write a model card into ``path`` and upload the folder to the Hub.

    The destination repository is ``mlx-community/{name}``. It is created
    when missing and reused when it already exists.

    Args:
        path: Directory holding the converted MLX model. A ``README.md``
            model card is written into this directory before the upload,
            overwriting any existing file of that name. Path-like objects
            work as well as strings.
        name: Repository name under the ``mlx-community`` organization; also
            used as the model card title.
        torch_name_or_path: Name or path of the source PyTorch model, quoted
            in the model card for provenance.
    """
    # Imported lazily so that huggingface_hub is only required when an
    # upload is actually requested.
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    repo_id = f"mlx-community/{name}"

    # The card is assembled line by line rather than from a triple-quoted
    # literal: the YAML front matter delimiters and the code fences must
    # start at column 0, which a literal inside an indented function body
    # only guarantees if nobody ever re-indents it.
    card_lines = [
        "---",
        "library_name: mlx",
        "---",
        "",
        f"# {name}",
        "",
        # The source is shown as inline code, not as a link: it may be a
        # local path, so there is no URL that is guaranteed to be valid.
        f"This model was converted to MLX format from `{torch_name_or_path}`.",
        "",
        "## Use with mlx",
        "",
        "```bash",
        "git clone https://github.com/ml-explore/mlx-examples.git",
        "cd mlx-examples/whisper/",
        "pip install -r requirements.txt",
        "```",
        "",
        # Python usage gets its own fence so it is highlighted correctly and
        # can be copy-pasted without stripping a prompt prefix.
        "```python",
        "import whisper",
        "",
        'whisper.transcribe("FILE_NAME")',
        "```",
        "",
    ]
    card = ModelCard("\n".join(card_lines))
    card.save(os.path.join(path, "README.md"))

    # Surface upload progress from huggingface_hub on the console.
    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=repo_id, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=repo_id,
        repo_type="model",
    )
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
@ -238,7 +276,7 @@ if __name__ == "__main__":
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_models/tiny",
default="mlx_models",
help="The path to save the MLX model.",
)
parser.add_argument(
@ -265,6 +303,13 @@ if __name__ == "__main__":
type=int,
default=4,
)
parser.add_argument(
"--upload-name",
help="The name of model to upload to Hugging Face MLX Community",
type=str,
default=None,
)
args = parser.parse_args()
assert (
@ -292,3 +337,6 @@ if __name__ == "__main__":
with open(str(mlx_path / "config.json"), "w") as f:
config["model_type"] = "whisper"
json.dump(config, f, indent=4)
if args.upload_name is not None:
upload_to_hub(mlx_path, args.upload_name, args.torch_name_or_path)

View File

@ -5,3 +5,5 @@ torch
tqdm
more-itertools
tiktoken==0.3.3
huggingface_hub
scipy