mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-11-04 05:28:11 +08:00 
			
		
		
		
	Added lora support for Phi-2 (#302)
* Added lora support for Phi-2 * Added Phi-2 support in fuse and convert * format + readme --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		@@ -2,7 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
This is an example of using MLX to fine-tune an LLM with low rank adaptation
 | 
					This is an example of using MLX to fine-tune an LLM with low rank adaptation
 | 
				
			||||||
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
 | 
					(LoRA) for a target task.[^lora] The example also supports quantized LoRA
 | 
				
			||||||
(QLoRA).[^qlora] The example works with Llama and Mistral style
 | 
					(QLoRA).[^qlora] The example works with Llama, Mistral, and Phi-2 style
 | 
				
			||||||
models available on Hugging Face.
 | 
					models available on Hugging Face.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
 | 
					In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
 | 
				
			||||||
@@ -81,7 +81,7 @@ To fine-tune a model use:
 | 
				
			|||||||
```
 | 
					```
 | 
				
			||||||
python lora.py --model <path_to_model> \
 | 
					python lora.py --model <path_to_model> \
 | 
				
			||||||
               --train \
 | 
					               --train \
 | 
				
			||||||
               --iters 600
 | 
					               --iters 600 \
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
If `--model` points to a quantized model, then the training will use QLoRA,
 | 
					If `--model` points to a quantized model, then the training will use QLoRA,
 | 
				
			||||||
@@ -100,7 +100,7 @@ To compute test set perplexity use:
 | 
				
			|||||||
```
 | 
					```
 | 
				
			||||||
python lora.py --model <path_to_model> \
 | 
					python lora.py --model <path_to_model> \
 | 
				
			||||||
               --adapter-file <path_to_adapters.npz> \
 | 
					               --adapter-file <path_to_adapters.npz> \
 | 
				
			||||||
               --test 
 | 
					               --test \
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Generate
 | 
					### Generate
 | 
				
			||||||
@@ -114,7 +114,7 @@ python lora.py --model <path_to_model> \
 | 
				
			|||||||
               --prompt "table: 1-10015132-16
 | 
					               --prompt "table: 1-10015132-16
 | 
				
			||||||
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
 | 
					columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
 | 
				
			||||||
Q: What is terrence ross' nationality
 | 
					Q: What is terrence ross' nationality
 | 
				
			||||||
A: "
 | 
					A: " \
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Results
 | 
					## Results
 | 
				
			||||||
@@ -211,7 +211,7 @@ python lora.py \
 | 
				
			|||||||
   --model mistralai/Mistral-7B-v0.1 \
 | 
					   --model mistralai/Mistral-7B-v0.1 \
 | 
				
			||||||
   --train \
 | 
					   --train \
 | 
				
			||||||
   --batch-size 1 \
 | 
					   --batch-size 1 \
 | 
				
			||||||
   --lora-layers 4
 | 
					   --lora-layers 4 \
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
 | 
					The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,14 +7,16 @@ import mlx.core as mx
 | 
				
			|||||||
import mlx.nn as nn
 | 
					import mlx.nn as nn
 | 
				
			||||||
import utils
 | 
					import utils
 | 
				
			||||||
from mlx.utils import tree_flatten
 | 
					from mlx.utils import tree_flatten
 | 
				
			||||||
from models import Model, ModelArgs
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def quantize(weights, config, args):
 | 
					def quantize(weights, config, args):
 | 
				
			||||||
    quantized_config = copy.deepcopy(config)
 | 
					    quantized_config = copy.deepcopy(config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Get model classes
 | 
				
			||||||
 | 
					    model_class, model_args_class = utils._get_classes(config=config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Load the model:
 | 
					    # Load the model:
 | 
				
			||||||
    model = Model(ModelArgs.from_dict(config))
 | 
					    model = model_class(model_args_class.from_dict(config))
 | 
				
			||||||
    model.load_weights(list(weights.items()))
 | 
					    model.load_weights(list(weights.items()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Quantize the model:
 | 
					    # Quantize the model:
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										10
									
								
								lora/fuse.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								lora/fuse.py
									
									
									
									
									
								
							@@ -4,9 +4,9 @@ import argparse
 | 
				
			|||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import mlx.core as mx
 | 
					import mlx.core as mx
 | 
				
			||||||
import models
 | 
					 | 
				
			||||||
import utils
 | 
					import utils
 | 
				
			||||||
from mlx.utils import tree_flatten, tree_unflatten
 | 
					from mlx.utils import tree_flatten, tree_unflatten
 | 
				
			||||||
 | 
					from models.lora import LoRALinear
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
 | 
					    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
 | 
				
			||||||
@@ -45,7 +45,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    print("Loading pretrained model")
 | 
					    print("Loading pretrained model")
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model, tokenizer, config = models.load(args.model)
 | 
					    model, tokenizer, config = utils.load(args.model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Load adapters and get number of LoRA layers
 | 
					    # Load adapters and get number of LoRA layers
 | 
				
			||||||
    adapters = list(mx.load(args.adapter_file).items())
 | 
					    adapters = list(mx.load(args.adapter_file).items())
 | 
				
			||||||
@@ -54,14 +54,14 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Freeze all layers other than LORA linears
 | 
					    # Freeze all layers other than LORA linears
 | 
				
			||||||
    model.freeze()
 | 
					    model.freeze()
 | 
				
			||||||
    for l in model.model.layers[-lora_layers:]:
 | 
					    for l in model.model.layers[-lora_layers:]:
 | 
				
			||||||
        l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
 | 
					        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
 | 
				
			||||||
        l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
 | 
					        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model.update(tree_unflatten(adapters))
 | 
					    model.update(tree_unflatten(adapters))
 | 
				
			||||||
    fused_linears = [
 | 
					    fused_linears = [
 | 
				
			||||||
        (n, m.to_linear())
 | 
					        (n, m.to_linear())
 | 
				
			||||||
        for n, m in model.named_modules()
 | 
					        for n, m in model.named_modules()
 | 
				
			||||||
        if isinstance(m, models.LoRALinear)
 | 
					        if isinstance(m, LoRALinear)
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model.update_modules(tree_unflatten(fused_linears))
 | 
					    model.update_modules(tree_unflatten(fused_linears))
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										11
									
								
								lora/lora.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								lora/lora.py
									
									
									
									
									
								
							@@ -9,9 +9,10 @@ from pathlib import Path
 | 
				
			|||||||
import mlx.core as mx
 | 
					import mlx.core as mx
 | 
				
			||||||
import mlx.nn as nn
 | 
					import mlx.nn as nn
 | 
				
			||||||
import mlx.optimizers as optim
 | 
					import mlx.optimizers as optim
 | 
				
			||||||
import models
 | 
					 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import utils as lora_utils
 | 
				
			||||||
from mlx.utils import tree_flatten, tree_unflatten
 | 
					from mlx.utils import tree_flatten, tree_unflatten
 | 
				
			||||||
 | 
					from models.lora import LoRALinear
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def build_parser():
 | 
					def build_parser():
 | 
				
			||||||
@@ -270,7 +271,7 @@ def generate(model, prompt, tokenizer, args):
 | 
				
			|||||||
    tokens = []
 | 
					    tokens = []
 | 
				
			||||||
    skip = 0
 | 
					    skip = 0
 | 
				
			||||||
    for token, n in zip(
 | 
					    for token, n in zip(
 | 
				
			||||||
        models.generate(prompt, model, args.temp),
 | 
					        lora_utils.generate(prompt, model, args.temp),
 | 
				
			||||||
        range(args.max_tokens),
 | 
					        range(args.max_tokens),
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        if token == tokenizer.eos_token_id:
 | 
					        if token == tokenizer.eos_token_id:
 | 
				
			||||||
@@ -294,13 +295,13 @@ if __name__ == "__main__":
 | 
				
			|||||||
    np.random.seed(args.seed)
 | 
					    np.random.seed(args.seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print("Loading pretrained model")
 | 
					    print("Loading pretrained model")
 | 
				
			||||||
    model, tokenizer, _ = models.load(args.model)
 | 
					    model, tokenizer, _ = lora_utils.load(args.model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Freeze all layers other than LORA linears
 | 
					    # Freeze all layers other than LORA linears
 | 
				
			||||||
    model.freeze()
 | 
					    model.freeze()
 | 
				
			||||||
    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
 | 
					    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
 | 
				
			||||||
        l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
 | 
					        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
 | 
				
			||||||
        l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
 | 
					        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
 | 
					    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
 | 
				
			||||||
    print(f"Total parameters {p:.3f}M")
 | 
					    print(f"Total parameters {p:.3f}M")
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										0
									
								
								lora/models/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								lora/models/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										15
									
								
								lora/models/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								lora/models/base.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,15 @@
 | 
				
			|||||||
 | 
					import inspect
 | 
				
			||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class BaseModelArgs:
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def from_dict(cls, params):
 | 
				
			||||||
 | 
					        return cls(
 | 
				
			||||||
 | 
					            **{
 | 
				
			||||||
 | 
					                k: v
 | 
				
			||||||
 | 
					                for k, v in params.items()
 | 
				
			||||||
 | 
					                if k in inspect.signature(cls).parameters
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
							
								
								
									
										202
									
								
								lora/models/llama.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										202
									
								
								lora/models/llama.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,202 @@
 | 
				
			|||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					from typing import Dict, Optional, Tuple, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import mlx.core as mx
 | 
				
			||||||
 | 
					import mlx.nn as nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .base import BaseModelArgs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class ModelArgs(BaseModelArgs):
 | 
				
			||||||
 | 
					    hidden_size: int
 | 
				
			||||||
 | 
					    num_hidden_layers: int
 | 
				
			||||||
 | 
					    intermediate_size: int
 | 
				
			||||||
 | 
					    num_attention_heads: int
 | 
				
			||||||
 | 
					    rms_norm_eps: float
 | 
				
			||||||
 | 
					    vocab_size: int
 | 
				
			||||||
 | 
					    num_key_value_heads: int = None
 | 
				
			||||||
 | 
					    rope_theta: float = 10000
 | 
				
			||||||
 | 
					    rope_traditional: bool = False
 | 
				
			||||||
 | 
					    model_type: str = None
 | 
				
			||||||
 | 
					    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __post_init__(self):
 | 
				
			||||||
 | 
					        if self.num_key_value_heads is None:
 | 
				
			||||||
 | 
					            self.num_key_value_heads = self.num_attention_heads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.rope_scaling:
 | 
				
			||||||
 | 
					            required_keys = {"factor", "type"}
 | 
				
			||||||
 | 
					            if not all(key in self.rope_scaling for key in required_keys):
 | 
				
			||||||
 | 
					                raise ValueError(f"rope_scaling must contain keys {required_keys}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.rope_scaling["type"] != "linear":
 | 
				
			||||||
 | 
					                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RMSNorm(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, dims: int, eps: float = 1e-5):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.weight = mx.ones((dims,))
 | 
				
			||||||
 | 
					        self.eps = eps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _norm(self, x):
 | 
				
			||||||
 | 
					        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, x):
 | 
				
			||||||
 | 
					        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
 | 
				
			||||||
 | 
					        return self.weight * output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Attention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, args: ModelArgs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        dim = args.hidden_size
 | 
				
			||||||
 | 
					        self.n_heads = n_heads = args.num_attention_heads
 | 
				
			||||||
 | 
					        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.repeats = n_heads // n_kv_heads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        head_dim = args.hidden_size // n_heads
 | 
				
			||||||
 | 
					        self.scale = head_dim**-0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
 | 
				
			||||||
 | 
					        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
 | 
				
			||||||
 | 
					        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
 | 
				
			||||||
 | 
					        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        rope_scale = (
 | 
				
			||||||
 | 
					            1 / args.rope_scaling["factor"]
 | 
				
			||||||
 | 
					            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
 | 
				
			||||||
 | 
					            else 1
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.rope = nn.RoPE(
 | 
				
			||||||
 | 
					            head_dim,
 | 
				
			||||||
 | 
					            traditional=args.rope_traditional,
 | 
				
			||||||
 | 
					            base=args.rope_theta,
 | 
				
			||||||
 | 
					            scale=rope_scale,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        x: mx.array,
 | 
				
			||||||
 | 
					        mask: Optional[mx.array] = None,
 | 
				
			||||||
 | 
					        cache: Optional[Tuple[mx.array, mx.array]] = None,
 | 
				
			||||||
 | 
					    ) -> mx.array:
 | 
				
			||||||
 | 
					        B, L, D = x.shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Prepare the queries, keys and values for the attention computation
 | 
				
			||||||
 | 
					        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
 | 
				
			||||||
 | 
					        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 | 
				
			||||||
 | 
					        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def repeat(a):
 | 
				
			||||||
 | 
					            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
 | 
				
			||||||
 | 
					            return a.reshape([B, self.n_heads, L, -1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.repeats > 1:
 | 
				
			||||||
 | 
					            keys, values = map(repeat, (keys, values))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if cache is not None:
 | 
				
			||||||
 | 
					            key_cache, value_cache = cache
 | 
				
			||||||
 | 
					            queries = self.rope(queries, offset=key_cache.shape[2])
 | 
				
			||||||
 | 
					            keys = self.rope(keys, offset=key_cache.shape[2])
 | 
				
			||||||
 | 
					            keys = mx.concatenate([key_cache, keys], axis=2)
 | 
				
			||||||
 | 
					            values = mx.concatenate([value_cache, values], axis=2)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            queries = self.rope(queries)
 | 
				
			||||||
 | 
					            keys = self.rope(keys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
 | 
				
			||||||
 | 
					        if mask is not None:
 | 
				
			||||||
 | 
					            scores += mask
 | 
				
			||||||
 | 
					        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
 | 
				
			||||||
 | 
					        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
 | 
				
			||||||
 | 
					        return self.o_proj(output), (keys, values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MLP(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, dim, hidden_dim):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
 | 
				
			||||||
 | 
					        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
 | 
				
			||||||
 | 
					        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, x) -> mx.array:
 | 
				
			||||||
 | 
					        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TransformerBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, args: ModelArgs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.num_attention_heads = args.num_attention_heads
 | 
				
			||||||
 | 
					        self.hidden_size = args.hidden_size
 | 
				
			||||||
 | 
					        self.self_attn = Attention(args)
 | 
				
			||||||
 | 
					        self.mlp = MLP(args.hidden_size, args.intermediate_size)
 | 
				
			||||||
 | 
					        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 | 
				
			||||||
 | 
					        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 | 
				
			||||||
 | 
					        self.args = args
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        x: mx.array,
 | 
				
			||||||
 | 
					        mask: Optional[mx.array] = None,
 | 
				
			||||||
 | 
					        cache: Optional[Tuple[mx.array, mx.array]] = None,
 | 
				
			||||||
 | 
					    ) -> mx.array:
 | 
				
			||||||
 | 
					        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
 | 
				
			||||||
 | 
					        h = x + r
 | 
				
			||||||
 | 
					        r = self.mlp(self.post_attention_layernorm(h))
 | 
				
			||||||
 | 
					        out = h + r
 | 
				
			||||||
 | 
					        return out, cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LlamaModel(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, args: ModelArgs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.args = args
 | 
				
			||||||
 | 
					        self.vocab_size = args.vocab_size
 | 
				
			||||||
 | 
					        self.num_hidden_layers = args.num_hidden_layers
 | 
				
			||||||
 | 
					        assert self.vocab_size > 0
 | 
				
			||||||
 | 
					        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
 | 
				
			||||||
 | 
					        self.layers = [
 | 
				
			||||||
 | 
					            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        inputs: mx.array,
 | 
				
			||||||
 | 
					        cache=None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        h = self.embed_tokens(inputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        mask = None
 | 
				
			||||||
 | 
					        if h.shape[1] > 1:
 | 
				
			||||||
 | 
					            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
 | 
				
			||||||
 | 
					            mask = mask.astype(h.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if cache is None:
 | 
				
			||||||
 | 
					            cache = [None] * len(self.layers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for e, layer in enumerate(self.layers):
 | 
				
			||||||
 | 
					            h, cache[e] = layer(h, mask, cache[e])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.norm(h), cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Model(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, args: ModelArgs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.model = LlamaModel(args)
 | 
				
			||||||
 | 
					        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        inputs: mx.array,
 | 
				
			||||||
 | 
					        cache=None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        out, cache = self.model(inputs, cache)
 | 
				
			||||||
 | 
					        return self.lm_head(out), cache
 | 
				
			||||||
							
								
								
									
										86
									
								
								lora/models/lora.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								lora/models/lora.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
				
			|||||||
 | 
					import math
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import mlx.core as mx
 | 
				
			||||||
 | 
					import mlx.nn as nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LoRALinear(nn.Module):
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def from_linear(linear: nn.Linear, rank: int = 8):
 | 
				
			||||||
 | 
					        # TODO remove when input_dims and output_dims are attributes
 | 
				
			||||||
 | 
					        # on linear and quantized linear
 | 
				
			||||||
 | 
					        output_dims, input_dims = linear.weight.shape
 | 
				
			||||||
 | 
					        if isinstance(linear, nn.QuantizedLinear):
 | 
				
			||||||
 | 
					            input_dims *= 32 // linear.bits
 | 
				
			||||||
 | 
					        lora_lin = LoRALinear(input_dims, output_dims, rank)
 | 
				
			||||||
 | 
					        lora_lin.linear = linear
 | 
				
			||||||
 | 
					        return lora_lin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def to_linear(self):
 | 
				
			||||||
 | 
					        linear = self.linear
 | 
				
			||||||
 | 
					        bias = "bias" in linear
 | 
				
			||||||
 | 
					        weight = linear.weight
 | 
				
			||||||
 | 
					        is_quantized = isinstance(linear, nn.QuantizedLinear)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Use the same type as the linear weight if not quantized
 | 
				
			||||||
 | 
					        dtype = weight.dtype
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if is_quantized:
 | 
				
			||||||
 | 
					            dtype = mx.float16
 | 
				
			||||||
 | 
					            weight = mx.dequantize(
 | 
				
			||||||
 | 
					                weight,
 | 
				
			||||||
 | 
					                linear.scales,
 | 
				
			||||||
 | 
					                linear.biases,
 | 
				
			||||||
 | 
					                linear.group_size,
 | 
				
			||||||
 | 
					                linear.bits,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        output_dims, input_dims = weight.shape
 | 
				
			||||||
 | 
					        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        lora_b = (self.scale * self.lora_b.T).astype(dtype)
 | 
				
			||||||
 | 
					        lora_a = self.lora_a.T.astype(dtype)
 | 
				
			||||||
 | 
					        fused_linear.weight = weight + lora_b @ lora_a
 | 
				
			||||||
 | 
					        if bias:
 | 
				
			||||||
 | 
					            fused_linear.bias = linear.bias
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if is_quantized:
 | 
				
			||||||
 | 
					            fused_linear = nn.QuantizedLinear.from_linear(
 | 
				
			||||||
 | 
					                fused_linear,
 | 
				
			||||||
 | 
					                linear.group_size,
 | 
				
			||||||
 | 
					                linear.bits,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return fused_linear
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        input_dims: int,
 | 
				
			||||||
 | 
					        output_dims: int,
 | 
				
			||||||
 | 
					        lora_rank: int = 8,
 | 
				
			||||||
 | 
					        bias: bool = False,
 | 
				
			||||||
 | 
					        scale: float = 20.0,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Regular linear layer weights
 | 
				
			||||||
 | 
					        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Scale for low-rank update
 | 
				
			||||||
 | 
					        self.scale = scale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Low rank lora weights
 | 
				
			||||||
 | 
					        scale = 1 / math.sqrt(input_dims)
 | 
				
			||||||
 | 
					        self.lora_a = mx.random.uniform(
 | 
				
			||||||
 | 
					            low=-scale,
 | 
				
			||||||
 | 
					            high=scale,
 | 
				
			||||||
 | 
					            shape=(input_dims, lora_rank),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, x):
 | 
				
			||||||
 | 
					        dtype = self.linear.weight.dtype
 | 
				
			||||||
 | 
					        if isinstance(self.linear, nn.QuantizedLinear):
 | 
				
			||||||
 | 
					            dtype = self.linear.scales.dtype
 | 
				
			||||||
 | 
					        y = self.linear(x.astype(dtype))
 | 
				
			||||||
 | 
					        z = (x @ self.lora_a) @ self.lora_b
 | 
				
			||||||
 | 
					        return y + self.scale * z
 | 
				
			||||||
							
								
								
									
										138
									
								
								lora/models/phi2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								lora/models/phi2.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,138 @@
 | 
				
			|||||||
 | 
					import math
 | 
				
			||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import mlx.core as mx
 | 
				
			||||||
 | 
					import mlx.nn as nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .base import BaseModelArgs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class ModelArgs(BaseModelArgs):
 | 
				
			||||||
 | 
					    n_positions: int = 2048
 | 
				
			||||||
 | 
					    vocab_size: int = 51200
 | 
				
			||||||
 | 
					    n_embd: int = 2560
 | 
				
			||||||
 | 
					    n_head: int = 32
 | 
				
			||||||
 | 
					    n_layer: int = 32
 | 
				
			||||||
 | 
					    rotary_dim: int = 32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LayerNorm(nn.LayerNorm):
 | 
				
			||||||
 | 
					    def __call__(self, x: mx.array) -> mx.array:
 | 
				
			||||||
 | 
					        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RoPEAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, dims: int, n_head: int, rotary_dim: int):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.n_head = n_head
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.q_proj = nn.Linear(dims, dims)
 | 
				
			||||||
 | 
					        self.k_proj = nn.Linear(dims, dims)
 | 
				
			||||||
 | 
					        self.v_proj = nn.Linear(dims, dims)
 | 
				
			||||||
 | 
					        self.dense = nn.Linear(dims, dims)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.rope = nn.RoPE(rotary_dim, traditional=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, x, mask=None, cache=None):
 | 
				
			||||||
 | 
					        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Extract some shapes
 | 
				
			||||||
 | 
					        n_head = self.n_head
 | 
				
			||||||
 | 
					        B, L, D = queries.shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Prepare the queries, keys and values for the attention computation
 | 
				
			||||||
 | 
					        queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
 | 
				
			||||||
 | 
					        keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
 | 
				
			||||||
 | 
					        values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Add RoPE to the queries and keys and combine them with the cache
 | 
				
			||||||
 | 
					        if cache is not None:
 | 
				
			||||||
 | 
					            key_cache, value_cache = cache
 | 
				
			||||||
 | 
					            queries = self.rope(queries, offset=key_cache.shape[2])
 | 
				
			||||||
 | 
					            keys = self.rope(keys, offset=key_cache.shape[2])
 | 
				
			||||||
 | 
					            keys = mx.concatenate([key_cache, keys], axis=2)
 | 
				
			||||||
 | 
					            values = mx.concatenate([value_cache, values], axis=2)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            queries = self.rope(queries)
 | 
				
			||||||
 | 
					            keys = self.rope(keys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        queries = queries.astype(mx.float32)
 | 
				
			||||||
 | 
					        keys = keys.astype(mx.float32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Finally perform the attention computation
 | 
				
			||||||
 | 
					        scale = math.sqrt(1 / queries.shape[-1])
 | 
				
			||||||
 | 
					        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
 | 
				
			||||||
 | 
					        if mask is not None:
 | 
				
			||||||
 | 
					            scores = scores + mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
 | 
				
			||||||
 | 
					        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.dense(values_hat), (keys, values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MLP(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, dim, hidden_dim):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.fc1 = nn.Linear(dim, hidden_dim)
 | 
				
			||||||
 | 
					        self.fc2 = nn.Linear(hidden_dim, dim)
 | 
				
			||||||
 | 
					        self.act = nn.GELU(approx="precise")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, x) -> mx.array:
 | 
				
			||||||
 | 
					        return self.fc2(self.act(self.fc1(x)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ParallelBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, config: ModelArgs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        dims = config.n_embd
 | 
				
			||||||
 | 
					        mlp_dims = dims * 4
 | 
				
			||||||
 | 
					        self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
 | 
				
			||||||
 | 
					        self.input_layernorm = LayerNorm(dims)
 | 
				
			||||||
 | 
					        self.mlp = MLP(dims, mlp_dims)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, x, mask, cache):
 | 
				
			||||||
 | 
					        h = self.input_layernorm(x)
 | 
				
			||||||
 | 
					        attn_h, cache = self.self_attn(h, mask, cache)
 | 
				
			||||||
 | 
					        ff_h = self.mlp(h)
 | 
				
			||||||
 | 
					        return attn_h + ff_h + x, cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Transformer(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, config: ModelArgs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
 | 
				
			||||||
 | 
					        self.layers = [ParallelBlock(config) for i in range(config.n_layer)]
 | 
				
			||||||
 | 
					        self.final_layernorm = LayerNorm(config.n_embd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, x, mask, cache):
 | 
				
			||||||
 | 
					        x = self.embed_tokens(x)
 | 
				
			||||||
 | 
					        if cache is None:
 | 
				
			||||||
 | 
					            cache = [None] * len(self.layers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for e, layer in enumerate(self.layers):
 | 
				
			||||||
 | 
					            x, cache[e] = layer(x, mask, cache[e])
 | 
				
			||||||
 | 
					        return self.final_layernorm(x), cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Model(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, config: ModelArgs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.model = Transformer(config)
 | 
				
			||||||
 | 
					        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        x: mx.array,
 | 
				
			||||||
 | 
					        mask: mx.array = None,
 | 
				
			||||||
 | 
					        cache: mx.array = None,
 | 
				
			||||||
 | 
					    ) -> tuple[mx.array, mx.array]:
 | 
				
			||||||
 | 
					        mask = None
 | 
				
			||||||
 | 
					        if x.shape[1] > 1:
 | 
				
			||||||
 | 
					            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
 | 
				
			||||||
 | 
					            mask = mask.astype(x.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        y, cache = self.model(x, mask, cache)
 | 
				
			||||||
 | 
					        return self.lm_head(y), cache
 | 
				
			||||||
							
								
								
									
										100
									
								
								lora/utils.py
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								lora/utils.py
									
									
									
									
									
								
							@@ -2,12 +2,44 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import glob
 | 
					import glob
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					from typing import Generator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import mlx.core as mx
 | 
					import mlx.core as mx
 | 
				
			||||||
 | 
					import mlx.nn as nn
 | 
				
			||||||
 | 
					import models.llama as llama
 | 
				
			||||||
 | 
					import models.phi2 as phi2
 | 
				
			||||||
import transformers
 | 
					import transformers
 | 
				
			||||||
from huggingface_hub import snapshot_download
 | 
					from huggingface_hub import snapshot_download
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Constants
 | 
				
			||||||
 | 
					MODEL_MAPPING = {
 | 
				
			||||||
 | 
					    "llama": llama,
 | 
				
			||||||
 | 
					    "mistral": llama,  # mistral is compatible with llama
 | 
				
			||||||
 | 
					    "phi": phi2,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_classes(config: dict):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Retrieve the model and model args classes based on the configuration.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        config (dict): The model configuration.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        A tuple containing the Model class and the ModelArgs class.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    model_type = config["model_type"]
 | 
				
			||||||
 | 
					    if model_type not in MODEL_MAPPING:
 | 
				
			||||||
 | 
					        msg = f"Model type {model_type} not supported."
 | 
				
			||||||
 | 
					        logging.error(msg)
 | 
				
			||||||
 | 
					        raise ValueError(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    arch = MODEL_MAPPING[model_type]
 | 
				
			||||||
 | 
					    return arch.Model, arch.ModelArgs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def fetch_from_hub(hf_path: str):
 | 
					def fetch_from_hub(hf_path: str):
 | 
				
			||||||
    model_path = snapshot_download(
 | 
					    model_path = snapshot_download(
 | 
				
			||||||
@@ -88,3 +120,71 @@ def save_model(save_dir: str, weights, tokenizer, config):
 | 
				
			|||||||
    tokenizer.save_pretrained(save_dir)
 | 
					    tokenizer.save_pretrained(save_dir)
 | 
				
			||||||
    with open(save_dir / "config.json", "w") as fid:
 | 
					    with open(save_dir / "config.json", "w") as fid:
 | 
				
			||||||
        json.dump(config, fid, indent=4)
 | 
					        json.dump(config, fid, indent=4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def load(path_or_hf_repo: str):
 | 
				
			||||||
 | 
					    # If the path exists, it will try to load model form it
 | 
				
			||||||
 | 
					    # otherwise download and cache from the hf_repo and cache
 | 
				
			||||||
 | 
					    model_path = Path(path_or_hf_repo)
 | 
				
			||||||
 | 
					    if not model_path.exists():
 | 
				
			||||||
 | 
					        model_path = Path(
 | 
				
			||||||
 | 
					            snapshot_download(
 | 
				
			||||||
 | 
					                repo_id=path_or_hf_repo,
 | 
				
			||||||
 | 
					                allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with open(model_path / "config.json", "r") as f:
 | 
				
			||||||
 | 
					        config = json.loads(f.read())
 | 
				
			||||||
 | 
					        quantization = config.get("quantization", None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    weight_files = glob.glob(str(model_path / "*.safetensors"))
 | 
				
			||||||
 | 
					    if len(weight_files) == 0:
 | 
				
			||||||
 | 
					        raise FileNotFoundError("No safetensors found in {}".format(model_path))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    weights = {}
 | 
				
			||||||
 | 
					    for wf in weight_files:
 | 
				
			||||||
 | 
					        weights.update(mx.load(wf).items())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_class, model_args_class = _get_classes(config=config)
 | 
				
			||||||
 | 
					    model_args = model_args_class.from_dict(config)
 | 
				
			||||||
 | 
					    model = model_class(model_args)
 | 
				
			||||||
 | 
					    if quantization is not None:
 | 
				
			||||||
 | 
					        nn.QuantizedLinear.quantize_module(model, **quantization)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model.load_weights(list(weights.items()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mx.eval(model.parameters())
 | 
				
			||||||
 | 
					    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
 | 
				
			||||||
 | 
					    return model, tokenizer, config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def generate(
 | 
				
			||||||
 | 
					    prompt: mx.array, model: nn.Module, temp: float = 0.0
 | 
				
			||||||
 | 
					) -> Generator[mx.array, None, None]:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Generate text based on the given prompt and model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        prompt (mx.array): The input prompt.
 | 
				
			||||||
 | 
					        model (nn.Module): The model to use for generation.
 | 
				
			||||||
 | 
					        temp (float): The temperature for sampling. If temp is 0, use max sampling.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Yields:
 | 
				
			||||||
 | 
					        mx.array: The generated text.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def sample(logits: mx.array) -> mx.array:
 | 
				
			||||||
 | 
					        return (
 | 
				
			||||||
 | 
					            mx.argmax(logits, axis=-1)
 | 
				
			||||||
 | 
					            if temp == 0
 | 
				
			||||||
 | 
					            else mx.random.categorical(logits * (1 / temp))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    y = prompt
 | 
				
			||||||
 | 
					    cache = None
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        logits, cache = model(y[None], cache=cache)
 | 
				
			||||||
 | 
					        logits = logits[:, -1, :]
 | 
				
			||||||
 | 
					        y = sample(logits)
 | 
				
			||||||
 | 
					        yield y
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,9 +3,9 @@
 | 
				
			|||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import mlx.core as mx
 | 
					import mlx.core as mx
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
from tqdm import tqdm
 | 
					from tqdm import tqdm
 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from stable_diffusion import StableDiffusion
 | 
					from stable_diffusion import StableDiffusion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,9 +5,8 @@ from pathlib import Path
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import mlx.core as mx
 | 
					import mlx.core as mx
 | 
				
			||||||
import mlx.nn as nn
 | 
					import mlx.nn as nn
 | 
				
			||||||
from mlx.utils import tree_unflatten
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from huggingface_hub import snapshot_download
 | 
					from huggingface_hub import snapshot_download
 | 
				
			||||||
 | 
					from mlx.utils import tree_unflatten
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import whisper
 | 
					from . import whisper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -18,11 +17,7 @@ def load_model(
 | 
				
			|||||||
) -> whisper.Whisper:
 | 
					) -> whisper.Whisper:
 | 
				
			||||||
    model_path = Path(path_or_hf_repo)
 | 
					    model_path = Path(path_or_hf_repo)
 | 
				
			||||||
    if not model_path.exists():
 | 
					    if not model_path.exists():
 | 
				
			||||||
        model_path = Path(
 | 
					        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
 | 
				
			||||||
            snapshot_download(
 | 
					 | 
				
			||||||
                repo_id=path_or_hf_repo
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with open(str(model_path / "config.json"), "r") as f:
 | 
					    with open(str(model_path / "config.json"), "r") as f:
 | 
				
			||||||
        config = json.loads(f.read())
 | 
					        config = json.loads(f.read())
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user