From 7575125d5dbb76213f32db10308cf200721acbfb Mon Sep 17 00:00:00 2001
From: Yousif
Date: Fri, 12 Jan 2024 13:45:30 -0800
Subject: [PATCH] Added lora support for Phi-2 (#302)

* Added lora support for Phi-2

* Added Phi-2 support in fuse and convert

* format + readme

---------

Co-authored-by: Awni Hannun
---
 lora/README.md                 |  10 +-
 lora/convert.py                |   6 +-
 lora/fuse.py                   |  10 +-
 lora/lora.py                   |  11 +-
 lora/models/__init__.py        |   0
 lora/models/base.py            |  15 +++
 lora/models/llama.py           | 202 +++++++++++++++++++++++++++++++++
 lora/models/lora.py            |  86 ++++++++++++++
 lora/models/phi2.py            | 138 ++++++++++++++++++++++
 lora/utils.py                  | 100 ++++++++++++++++
 stable_diffusion/txt2image.py  |   2 +-
 whisper/whisper/load_models.py |   9 +-
 12 files changed, 564 insertions(+), 25 deletions(-)
 create mode 100644 lora/models/__init__.py
 create mode 100644 lora/models/base.py
 create mode 100644 lora/models/llama.py
 create mode 100644 lora/models/lora.py
 create mode 100644 lora/models/phi2.py

diff --git a/lora/README.md b/lora/README.md
index a3476cbd..d9de523c 100644
--- a/lora/README.md
+++ b/lora/README.md
@@ -2,7 +2,7 @@
 
 This is an example of using MLX to fine-tune an LLM with low rank adaptation
 (LoRA) for a target task.[^lora] The example also supports quantized LoRA
-(QLoRA).[^qlora] The example works with Llama and Mistral style
+(QLoRA).[^qlora] The example works with Llama, Mistral, and Phi-2 style
 models available on Hugging Face.
 
 In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
@@ -81,7 +81,7 @@ To fine-tune a model use:
 ```
 python lora.py --model <path_to_model> \
                --train \
-               --iters 600 \
+               --iters 600
 ```
 
 If `--model` points to a quantized model, then the training will use QLoRA,
@@ -100,7 +100,7 @@ To compute test set perplexity use:
 ```
 python lora.py --model <path_to_model> \
                --adapter-file <path_to_adapters.npz> \
-               --test \
+               --test
 ```
 
 ### Generate
@@ -114,7 +114,7 @@ python lora.py --model <path_to_model> \
                --adapter-file <path_to_adapters.npz> \
                --max-tokens 50 \
                --prompt "table: 1-10015132-16
 columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
 Q: What is terrence ross' nationality
-A: " \
+A: "
 ```
 
 ## Results
@@ -211,7 +211,7 @@ python lora.py \
    --model mistralai/Mistral-7B-v0.1 \
    --train \
    --batch-size 1 \
-   --lora-layers 4 \
+   --lora-layers 4
 ```
 
 The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
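The rest of the patch replaces the old single-file `models.py` with per-architecture modules (`models/llama.py`, `models/phi2.py`), a shared `models/lora.py`, and loading/generation helpers in `lora/utils.py`. As a quick orientation, here is a minimal sketch of how those pieces fit together once the patch is applied; it simply mirrors the wiring done in `lora.py` and `fuse.py` further down, and the model path is a placeholder rather than a value taken from this patch.

```
# Orientation sketch only -- mirrors the wiring in lora.py and fuse.py below.
# "path/to/model-or-hf-repo" is a placeholder, not a path from this patch.
import utils
from models.lora import LoRALinear

model, tokenizer, config = utils.load("path/to/model-or-hf-repo")

# Freeze the base model, then swap the last few attention projections for
# LoRA layers; the attribute names line up for both Llama and Phi-2 models.
model.freeze()
lora_layers = 4
for layer in model.model.layers[-lora_layers:]:
    layer.self_attn.q_proj = LoRALinear.from_linear(layer.self_attn.q_proj)
    layer.self_attn.v_proj = LoRALinear.from_linear(layer.self_attn.v_proj)
```

Keeping the `model.model.layers[...].self_attn.{q,v}_proj` layout identical across architectures is what lets `lora.py` and `fuse.py` stay model-agnostic.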
diff --git a/lora/convert.py b/lora/convert.py index 98697587..9b2f6de6 100644 --- a/lora/convert.py +++ b/lora/convert.py @@ -7,14 +7,16 @@ import mlx.core as mx import mlx.nn as nn import utils from mlx.utils import tree_flatten -from models import Model, ModelArgs def quantize(weights, config, args): quantized_config = copy.deepcopy(config) + # Get model classes + model_class, model_args_class = utils._get_classes(config=config) + # Load the model: - model = Model(ModelArgs.from_dict(config)) + model = model_class(model_args_class.from_dict(config)) model.load_weights(list(weights.items())) # Quantize the model: diff --git a/lora/fuse.py b/lora/fuse.py index ba34fe2a..bde543b4 100644 --- a/lora/fuse.py +++ b/lora/fuse.py @@ -4,9 +4,9 @@ import argparse from pathlib import Path import mlx.core as mx -import models import utils from mlx.utils import tree_flatten, tree_unflatten +from models.lora import LoRALinear if __name__ == "__main__": parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") @@ -45,7 +45,7 @@ if __name__ == "__main__": print("Loading pretrained model") args = parser.parse_args() - model, tokenizer, config = models.load(args.model) + model, tokenizer, config = utils.load(args.model) # Load adapters and get number of LoRA layers adapters = list(mx.load(args.adapter_file).items()) @@ -54,14 +54,14 @@ if __name__ == "__main__": # Freeze all layers other than LORA linears model.freeze() for l in model.model.layers[-lora_layers:]: - l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj) - l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj) + l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) + l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) model.update(tree_unflatten(adapters)) fused_linears = [ (n, m.to_linear()) for n, m in model.named_modules() - if isinstance(m, models.LoRALinear) + if isinstance(m, LoRALinear) ] model.update_modules(tree_unflatten(fused_linears)) diff --git a/lora/lora.py b/lora/lora.py index 3f1f085f..fba22dd8 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -9,9 +9,10 @@ from pathlib import Path import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim -import models import numpy as np +import utils as lora_utils from mlx.utils import tree_flatten, tree_unflatten +from models.lora import LoRALinear def build_parser(): @@ -270,7 +271,7 @@ def generate(model, prompt, tokenizer, args): tokens = [] skip = 0 for token, n in zip( - models.generate(prompt, model, args.temp), + lora_utils.generate(prompt, model, args.temp), range(args.max_tokens), ): if token == tokenizer.eos_token_id: @@ -294,13 +295,13 @@ if __name__ == "__main__": np.random.seed(args.seed) print("Loading pretrained model") - model, tokenizer, _ = models.load(args.model) + model, tokenizer, _ = lora_utils.load(args.model) # Freeze all layers other than LORA linears model.freeze() for l in model.model.layers[len(model.model.layers) - args.lora_layers :]: - l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj) - l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj) + l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) + l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 print(f"Total parameters {p:.3f}M") diff --git a/lora/models/__init__.py b/lora/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lora/models/base.py b/lora/models/base.py new 
file mode 100644 index 00000000..d1ea0b2c --- /dev/null +++ b/lora/models/base.py @@ -0,0 +1,15 @@ +import inspect +from dataclasses import dataclass + + +@dataclass +class BaseModelArgs: + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) diff --git a/lora/models/llama.py b/lora/models/llama.py new file mode 100644 index 00000000..ee026363 --- /dev/null +++ b/lora/models/llama.py @@ -0,0 +1,202 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + model_type: str = None + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.repeats = n_heads // n_kv_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + rope_scale = ( + 1 / args.rope_scaling["factor"] + if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" + else 1 + ) + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + if self.repeats > 1: + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, 
offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model = LlamaModel(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + return self.lm_head(out), cache diff --git a/lora/models/lora.py b/lora/models/lora.py new file mode 100644 index 00000000..3f584cfd --- /dev/null +++ b/lora/models/lora.py @@ -0,0 +1,86 @@ +import math + +import mlx.core as mx +import mlx.nn as nn + + +class LoRALinear(nn.Module): + @staticmethod + def from_linear(linear: nn.Linear, rank: int = 8): + # TODO remove when input_dims and output_dims are attributes + # on linear and quantized linear + output_dims, input_dims = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + input_dims *= 32 // linear.bits + lora_lin = LoRALinear(input_dims, output_dims, rank) + lora_lin.linear = linear + return lora_lin + + def to_linear(self): + 
linear = self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, nn.QuantizedLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = mx.float16 + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = (self.scale * self.lora_b.T).astype(dtype) + lora_a = self.lora_a.T.astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized: + fused_linear = nn.QuantizedLinear.from_linear( + fused_linear, + linear.group_size, + linear.bits, + ) + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + lora_rank: int = 8, + bias: bool = False, + scale: float = 20.0, + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, lora_rank), + ) + self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) + + def __call__(self, x): + dtype = self.linear.weight.dtype + if isinstance(self.linear, nn.QuantizedLinear): + dtype = self.linear.scales.dtype + y = self.linear(x.astype(dtype)) + z = (x @ self.lora_a) @ self.lora_b + return y + self.scale * z diff --git a/lora/models/phi2.py b/lora/models/phi2.py new file mode 100644 index 00000000..51b5e390 --- /dev/null +++ b/lora/models/phi2.py @@ -0,0 +1,138 @@ +import math +from dataclasses import dataclass + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + n_positions: int = 2048 + vocab_size: int = 51200 + n_embd: int = 2560 + n_head: int = 32 + n_layer: int = 32 + rotary_dim: int = 32 + + +class LayerNorm(nn.LayerNorm): + def __call__(self, x: mx.array) -> mx.array: + return super().__call__(x.astype(mx.float32)).astype(x.dtype) + + +class RoPEAttention(nn.Module): + def __init__(self, dims: int, n_head: int, rotary_dim: int): + super().__init__() + + self.n_head = n_head + + self.q_proj = nn.Linear(dims, dims) + self.k_proj = nn.Linear(dims, dims) + self.v_proj = nn.Linear(dims, dims) + self.dense = nn.Linear(dims, dims) + + self.rope = nn.RoPE(rotary_dim, traditional=False) + + def __call__(self, x, mask=None, cache=None): + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Extract some shapes + n_head = self.n_head + B, L, D = queries.shape + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3) + + # Add RoPE to the queries and keys and combine them with the cache + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + queries = queries.astype(mx.float32) + keys = keys.astype(mx.float32) + + # Finally perform the attention 
computation + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores = scores + mask + + scores = mx.softmax(scores, axis=-1).astype(values.dtype) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.dense(values_hat), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, dim) + self.act = nn.GELU(approx="precise") + + def __call__(self, x) -> mx.array: + return self.fc2(self.act(self.fc1(x))) + + +class ParallelBlock(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + dims = config.n_embd + mlp_dims = dims * 4 + self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim) + self.input_layernorm = LayerNorm(dims) + self.mlp = MLP(dims, mlp_dims) + + def __call__(self, x, mask, cache): + h = self.input_layernorm(x) + attn_h, cache = self.self_attn(h, mask, cache) + ff_h = self.mlp(h) + return attn_h + ff_h + x, cache + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd) + self.layers = [ParallelBlock(config) for i in range(config.n_layer)] + self.final_layernorm = LayerNorm(config.n_embd) + + def __call__(self, x, mask, cache): + x = self.embed_tokens(x) + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + x, cache[e] = layer(x, mask, cache[e]) + return self.final_layernorm(x), cache + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.model = Transformer(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + + def __call__( + self, + x: mx.array, + mask: mx.array = None, + cache: mx.array = None, + ) -> tuple[mx.array, mx.array]: + mask = None + if x.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(x.dtype) + + y, cache = self.model(x, mask, cache) + return self.lm_head(y), cache diff --git a/lora/utils.py b/lora/utils.py index 182c66dc..b691227d 100644 --- a/lora/utils.py +++ b/lora/utils.py @@ -2,12 +2,44 @@ import glob import json +import logging from pathlib import Path +from typing import Generator import mlx.core as mx +import mlx.nn as nn +import models.llama as llama +import models.phi2 as phi2 import transformers from huggingface_hub import snapshot_download +# Constants +MODEL_MAPPING = { + "llama": llama, + "mistral": llama, # mistral is compatible with llama + "phi": phi2, +} + + +def _get_classes(config: dict): + """ + Retrieve the model and model args classes based on the configuration. + + Args: + config (dict): The model configuration. + + Returns: + A tuple containing the Model class and the ModelArgs class. + """ + model_type = config["model_type"] + if model_type not in MODEL_MAPPING: + msg = f"Model type {model_type} not supported." 
+ logging.error(msg) + raise ValueError(msg) + + arch = MODEL_MAPPING[model_type] + return arch.Model, arch.ModelArgs + def fetch_from_hub(hf_path: str): model_path = snapshot_download( @@ -88,3 +120,71 @@ def save_model(save_dir: str, weights, tokenizer, config): tokenizer.save_pretrained(save_dir) with open(save_dir / "config.json", "w") as fid: json.dump(config, fid, indent=4) + + +def load(path_or_hf_repo: str): + # If the path exists, it will try to load model form it + # otherwise download and cache from the hf_repo and cache + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + ) + + with open(model_path / "config.json", "r") as f: + config = json.loads(f.read()) + quantization = config.get("quantization", None) + + weight_files = glob.glob(str(model_path / "*.safetensors")) + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + model_class, model_args_class = _get_classes(config=config) + model_args = model_args_class.from_dict(config) + model = model_class(model_args) + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) + + model.load_weights(list(weights.items())) + + mx.eval(model.parameters()) + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + return model, tokenizer, config + + +def generate( + prompt: mx.array, model: nn.Module, temp: float = 0.0 +) -> Generator[mx.array, None, None]: + """ + Generate text based on the given prompt and model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + temp (float): The temperature for sampling. If temp is 0, use max sampling. + + Yields: + mx.array: The generated text. + """ + + def sample(logits: mx.array) -> mx.array: + return ( + mx.argmax(logits, axis=-1) + if temp == 0 + else mx.random.categorical(logits * (1 / temp)) + ) + + y = prompt + cache = None + while True: + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits) + yield y diff --git a/stable_diffusion/txt2image.py b/stable_diffusion/txt2image.py index 41fcd4f4..9c49e1d2 100644 --- a/stable_diffusion/txt2image.py +++ b/stable_diffusion/txt2image.py @@ -3,9 +3,9 @@ import argparse import mlx.core as mx +import numpy as np from PIL import Image from tqdm import tqdm -import numpy as np from stable_diffusion import StableDiffusion diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 0d21ac1e..e2e567a3 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -5,9 +5,8 @@ from pathlib import Path import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_unflatten - from huggingface_hub import snapshot_download +from mlx.utils import tree_unflatten from . import whisper @@ -18,11 +17,7 @@ def load_model( ) -> whisper.Whisper: model_path = Path(path_or_hf_repo) if not model_path.exists(): - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo - ) - ) + model_path = Path(snapshot_download(repo_id=path_or_hf_repo)) with open(str(model_path / "config.json"), "r") as f: config = json.loads(f.read())
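To close the loop on the new `lora/utils.py` helpers above: `load` dispatches on `config["model_type"]` through `MODEL_MAPPING`, and `generate` is a token-by-token generator consumed the same way as the loop in `lora.py`'s `generate`. Below is a small usage sketch of that flow; the model path, prompt text, and 100-token cap are illustrative placeholders only.

```
# Usage sketch for utils.load and utils.generate (mirrors lora.py's generate loop).
# The model path, prompt text, and 100-token cap are illustrative placeholders.
import mlx.core as mx
import utils

model, tokenizer, _ = utils.load("path/to/model-or-hf-repo")

prompt = mx.array(tokenizer.encode("Q: What is terrence ross' nationality\nA: "))

tokens = []
for token, _ in zip(utils.generate(prompt, model, temp=0.8), range(100)):
    if token.item() == tokenizer.eos_token_id:
        break
    tokens.append(token.item())

print(tokenizer.decode(tokens))
```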