mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-09 02:16:37 +08:00
[Whisper] Add HF Hub upload option. (#254)
* Add HF Hub upload option. * up. * Add missing requirements.
This commit is contained in:
parent
6e5b0de4d3
commit
d4c3a9cb54
@ -206,6 +206,44 @@ def torch_to_mlx(
|
|||||||
return mlx_model
|
return mlx_model
|
||||||
|
|
||||||
|
|
||||||
|
def upload_to_hub(path: str, name: str, torch_name_or_path: str):
|
||||||
|
import os
|
||||||
|
|
||||||
|
from huggingface_hub import HfApi, ModelCard, logging
|
||||||
|
|
||||||
|
repo_id = f"mlx-community/{name}"
|
||||||
|
text = f"""
|
||||||
|
---
|
||||||
|
library_name: mlx
|
||||||
|
---
|
||||||
|
|
||||||
|
# {name}
|
||||||
|
This model was converted to MLX format from [`{torch_name_or_path}`]().
|
||||||
|
|
||||||
|
## Use with mlx
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/ml-explore/mlx-examples.git
|
||||||
|
cd mlx-examples/whisper/
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
>> import whisper
|
||||||
|
>> whisper.transcribe("FILE_NAME")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
card = ModelCard(text)
|
||||||
|
card.save(os.path.join(path, "README.md"))
|
||||||
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
api.create_repo(repo_id=repo_id, exist_ok=True)
|
||||||
|
api.upload_folder(
|
||||||
|
folder_path=path,
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def quantize(weights, config, args):
|
def quantize(weights, config, args):
|
||||||
quantized_config = copy.deepcopy(config)
|
quantized_config = copy.deepcopy(config)
|
||||||
|
|
||||||
@ -238,7 +276,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mlx-path",
|
"--mlx-path",
|
||||||
type=str,
|
type=str,
|
||||||
default="mlx_models/tiny",
|
default="mlx_models",
|
||||||
help="The path to save the MLX model.",
|
help="The path to save the MLX model.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -265,6 +303,13 @@ if __name__ == "__main__":
|
|||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--upload-name",
|
||||||
|
help="The name of model to upload to Hugging Face MLX Community",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@ -292,3 +337,6 @@ if __name__ == "__main__":
|
|||||||
with open(str(mlx_path / "config.json"), "w") as f:
|
with open(str(mlx_path / "config.json"), "w") as f:
|
||||||
config["model_type"] = "whisper"
|
config["model_type"] = "whisper"
|
||||||
json.dump(config, f, indent=4)
|
json.dump(config, f, indent=4)
|
||||||
|
|
||||||
|
if args.upload_name is not None:
|
||||||
|
upload_to_hub(mlx_path, args.upload_name, args.torch_name_or_path)
|
||||||
|
@ -5,3 +5,5 @@ torch
|
|||||||
tqdm
|
tqdm
|
||||||
more-itertools
|
more-itertools
|
||||||
tiktoken==0.3.3
|
tiktoken==0.3.3
|
||||||
|
huggingface_hub
|
||||||
|
scipy
|
||||||
|
Loading…
Reference in New Issue
Block a user