diff --git a/whisper/convert.py b/whisper/convert.py
index 2e4ebce5..c15623d1 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -206,6 +206,57 @@ def torch_to_mlx(
     return mlx_model
 
 
+def upload_to_hub(path: str, name: str, torch_name_or_path: str):
+    """Write a model card into ``path`` and upload the folder to the Hub.
+
+    Args:
+        path: Directory holding the converted MLX model; ``README.md`` is
+            written here and the whole folder is uploaded.
+        name: Repository name, created under the ``mlx-community``
+            organization if it does not already exist.
+        torch_name_or_path: Source the model was converted from; it is only
+            mentioned in the generated model card.
+    """
+    # Imported here so ``huggingface_hub`` is only loaded when uploading.
+    import os
+
+    from huggingface_hub import HfApi, ModelCard, logging
+
+    repo_id = f"mlx-community/{name}"
+    text = f"""
+---
+library_name: mlx
+---
+
+# {name}
+This model was converted to MLX format from `{torch_name_or_path}`.
+
+## Use with mlx
+```bash
+git clone https://github.com/ml-explore/mlx-examples.git
+cd mlx-examples/whisper/
+pip install -r requirements.txt
+```
+
+```python
+>>> import whisper
+>>> whisper.transcribe("FILE_NAME")
+```
+"""
+    card = ModelCard(text)
+    card.save(os.path.join(path, "README.md"))
+
+    logging.set_verbosity_info()
+
+    api = HfApi()
+    api.create_repo(repo_id=repo_id, exist_ok=True)
+    api.upload_folder(
+        folder_path=path,
+        repo_id=repo_id,
+        repo_type="model",
+    )
+
+
 def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
 
@@ -238,7 +289,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--mlx-path",
         type=str,
-        default="mlx_models/tiny",
+        default="mlx_models",
         help="The path to save the MLX model.",
     )
     parser.add_argument(
@@ -265,6 +316,13 @@ if __name__ == "__main__":
         type=int,
         default=4,
     )
+    parser.add_argument(
+        "--upload-name",
+        help="The name of model to upload to Hugging Face MLX Community",
+        type=str,
+        default=None,
+    )
+
     args = parser.parse_args()
 
     assert (
@@ -292,3 +350,6 @@ if __name__ == "__main__":
     with open(str(mlx_path / "config.json"), "w") as f:
         config["model_type"] = "whisper"
         json.dump(config, f, indent=4)
+
+    if args.upload_name is not None:
+        upload_to_hub(str(mlx_path), args.upload_name, args.torch_name_or_path)
diff --git a/whisper/requirements.txt b/whisper/requirements.txt
index 3baed45b..e4dbf8d1 100644
--- a/whisper/requirements.txt
+++ b/whisper/requirements.txt
@@ -5,3 +5,5 @@ torch
 tqdm
 more-itertools
 tiktoken==0.3.3
+huggingface_hub
+scipy