mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 19:18:09 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			359 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			359 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import glob
 | |
| import inspect
 | |
| import json
 | |
| import math
 | |
| from dataclasses import dataclass
 | |
| from pathlib import Path
 | |
| from typing import Dict, List, Optional, Tuple, Union
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx.nn as nn
 | |
| import numpy as np
 | |
| from huggingface_hub import snapshot_download
 | |
| from transformers import AutoTokenizer
 | |
| 
 | |
| 
 | |
@dataclass
class ModelArgs:
    """Hyper-parameters of the Llama model, typically parsed from ``config.json``.

    Use :meth:`from_dict` to build an instance from a config dict; keys that
    are not fields of this dataclass are ignored.
    """

    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    # When left unset, __post_init__ copies num_attention_heads into it.
    num_key_value_heads: Optional[int] = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    model_type: Optional[str] = None
    # Optional RoPE scaling spec; must contain "factor" and "type" == "linear".
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        """Fill derived defaults and validate ``rope_scaling``.

        Raises:
            ValueError: if ``rope_scaling`` is non-empty but lacks the
                ``factor`` or ``type`` key, or if its ``type`` is not
                ``"linear"``.
        """
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")

    @classmethod
    def from_dict(cls, params):
        """Construct an instance from ``params``, silently dropping unknown keys.

        Args:
            params: Mapping of field names to values (e.g. a parsed config).

        Returns:
            A new ``ModelArgs`` built from the keys that match its fields.
        """
        # Compute the accepted parameter names once, not once per config key.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})
 | |
| 
 | |
| 
 | |
class LoRALinear(nn.Module):
    """A linear layer plus a trainable low-rank (LoRA) update.

    The wrapped ``linear`` (plain ``nn.Linear`` or ``nn.QuantizedLinear``)
    computes the base projection. The product ``lora_a @ lora_b``, multiplied
    by ``scale``, is added to the base output.
    """

    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8):
        """Wrap an existing (possibly quantized) linear layer with LoRA factors.

        Args:
            linear: The ``nn.Linear`` or ``nn.QuantizedLinear`` to wrap. It is
                reused by reference, not copied.
            rank: Rank of the low-rank update.

        Returns:
            A ``LoRALinear`` whose ``linear`` attribute is ``linear``.
        """
        # TODO remove when input_dims and output_dims are attributes
        # on linear and quantized linear
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            # The stored quantized weight is narrower than the logical input
            # by a factor of 32 // bits (assumes packed 32-bit storage).
            input_dims *= 32 // linear.bits
        lora_lin = LoRALinear(input_dims, output_dims, rank)
        lora_lin.linear = linear
        return lora_lin

    def to_linear(self):
        """Fold the LoRA update into a single standalone linear layer.

        The new weight is ``W + (scale * lora_b.T) @ lora_a.T``. If the wrapped
        layer is quantized, ``W`` is dequantized first and the fused layer is
        re-quantized with the original ``group_size`` and ``bits``.

        Returns:
            An ``nn.Linear`` (or ``nn.QuantizedLinear`` if the wrapped layer was
            quantized), carrying over the original bias if there is one.
        """
        linear = self.linear
        has_bias = "bias" in linear
        weight = linear.weight
        is_quantized = isinstance(linear, nn.QuantizedLinear)

        # Fuse in the dtype the base layer computes in. For a plain layer that
        # is the weight dtype. For a quantized layer it is the scales dtype,
        # matching __call__. A hard-coded float16 would silently change the
        # precision of bfloat16/float32 quantized models.
        dtype = weight.dtype
        if is_quantized:
            dtype = linear.scales.dtype
            weight = mx.dequantize(
                weight,
                linear.scales,
                linear.biases,
                linear.group_size,
                linear.bits,
            )
        output_dims, input_dims = weight.shape
        fused_linear = nn.Linear(input_dims, output_dims, bias=has_bias)

        lora_b = (self.scale * self.lora_b.T).astype(dtype)
        lora_a = self.lora_a.T.astype(dtype)
        fused_linear.weight = weight + lora_b @ lora_a
        if has_bias:
            fused_linear.bias = linear.bias

        if is_quantized:
            fused_linear = nn.QuantizedLinear.from_linear(
                fused_linear,
                linear.group_size,
                linear.bits,
            )

        return fused_linear

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        lora_rank: int = 8,
        bias: bool = False,
        scale: float = 20.0,
    ):
        """Create a fresh base linear layer and LoRA factors.

        Args:
            input_dims: Input feature size.
            output_dims: Output feature size.
            lora_rank: Rank of the low-rank update.
            bias: Whether the base linear layer has a bias.
            scale: Multiplier applied to the low-rank branch output.
        """
        super().__init__()

        # Regular linear layer weights
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)

        # Scale for low-rank update
        self.scale = scale

        # Low rank lora weights. lora_a is drawn from U(-1/sqrt(in), 1/sqrt(in)).
        # lora_b starts at zero, so the update is initially zero and the layer
        # starts out equal to the base linear layer.
        init_bound = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-init_bound,
            high=init_bound,
            shape=(input_dims, lora_rank),
        )
        self.lora_b = mx.zeros(shape=(lora_rank, output_dims))

    def __call__(self, x):
        """Return ``linear(x) + scale * ((x @ lora_a) @ lora_b)``."""
        # Feed the base layer in the dtype it computes in: the weight dtype,
        # or the scales dtype when quantized.
        dtype = self.linear.weight.dtype
        if isinstance(self.linear, nn.QuantizedLinear):
            dtype = self.linear.scales.dtype
        y = self.linear(x.astype(dtype))
        # The low-rank branch uses x in its original dtype.
        z = (x @ self.lora_a) @ self.lora_b
        return y + self.scale * z
 | |
| 
 | |
| 
 | |
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned gain."""

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        # Scale each last-axis vector by 1 / sqrt(mean(x^2) + eps).
        mean_square = mx.mean(mx.square(x), axis=-1, keepdims=True)
        return x * mx.rsqrt(mean_square + self.eps)

    def __call__(self, x):
        # Normalize in float32, cast back to the caller's dtype, then apply
        # the per-feature gain.
        original_dtype = x.dtype
        normalized = self._norm(x.astype(mx.float32)).astype(original_dtype)
        return self.weight * normalized
 | |
| 
 | |
| 
 | |
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Grouped-query attention is supported: when ``num_key_value_heads`` is
    smaller than ``num_attention_heads``, each key/value head is repeated
    ``n_heads // n_kv_heads`` times so it lines up with the query heads.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        # How many query heads share each key/value head.
        self.repeats = n_heads // n_kv_heads

        head_dim = args.hidden_size // n_heads
        # 1 / sqrt(head_dim), applied to the queries before the dot product.
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        # For linear rope_scaling, RoPE's `scale` is 1 / factor; otherwise 1.
        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: Input unpacked as ``(B, L, D)``.
            mask: Optional additive mask, added to the attention scores before
                the softmax. It must broadcast against scores of shape
                ``(B, n_heads, L, S)``, where ``S`` includes cached positions.
            cache: Optional ``(keys, values)`` returned by a previous call. The
                cached sequence length (axis 2) is used as the RoPE offset.

        Returns:
            ``(output, (keys, values))``. ``output`` is the ``o_proj``
            projection, and ``(keys, values)`` are the full key/value tensors
            (cache + new) to pass back in as ``cache``. Keys and values are
            stored after head repetition, so the head axis has ``n_heads``
            entries rather than ``n_kv_heads``. Keys are stored after RoPE.
        """
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        def repeat(a):
            # (B, n_kv, L, hd) -> (B, n_kv, repeats, L, hd) -> (B, n_heads, L, hd):
            # kv head i fills query heads i*repeats .. (i+1)*repeats - 1.
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.n_heads, L, -1])

        if self.repeats > 1:
            keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            # New tokens are positioned after the cached ones.
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        # Softmax in float32 for numerical stability, then cast back.
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
 | |
| 
 | |
| 
 | |
class MLP(nn.Module):
    """Gated feed-forward network: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        # SiLU-activated gate, multiplied elementwise by the up projection.
        gate = nn.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
 | |
| 
 | |
| 
 | |
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: RMSNorm -> attention and RMSNorm -> MLP, each
    followed by a residual add."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Run one decoder block.

        Args:
            x: Hidden states.
            mask: Optional additive attention mask, forwarded to ``self_attn``.
            cache: Optional ``(keys, values)`` from a previous call of this block.

        Returns:
            ``(hidden_states, cache)``, where ``cache`` is the updated
            ``(keys, values)`` pair produced by ``self_attn``.
        """
        attn_out, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(h))
        return h + mlp_out, cache
 | |
| 
 | |
| 
 | |
class LlamaModel(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, and a final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs``, run every layer, and return ``(hidden, cache)``.

        ``cache`` is a per-layer list of ``(keys, values)`` (or ``None`` on the
        first call). A list passed in is updated in place and returned.
        """
        h = self.embed_tokens(inputs)

        # A causal mask is only needed when more than one position is processed.
        seq_len = h.shape[1]
        mask = (
            nn.MultiHeadAttention.create_additive_causal_mask(seq_len).astype(h.dtype)
            if seq_len > 1
            else None
        )

        layer_cache = [None] * len(self.layers) if cache is None else cache

        for idx, layer in enumerate(self.layers):
            h, layer_cache[idx] = layer(h, mask, layer_cache[idx])

        return self.norm(h), layer_cache
 | |
| 
 | |
| 
 | |
class Model(nn.Module):
    """Causal language model: a ``LlamaModel`` backbone followed by ``lm_head``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model = LlamaModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return ``(logits, cache)``; ``cache`` is the backbone's per-layer cache."""
        hidden_states, updated_cache = self.model(inputs, cache)
        logits = self.lm_head(hidden_states)
        return logits, updated_cache
 | |
| 
 | |
| 
 | |
def load(path_or_hf_repo: str):
    """Load a model, its tokenizer, and its raw config.

    Args:
        path_or_hf_repo: A local model directory. If no such path exists, it
            is treated as a Hugging Face Hub repo id and fetched with
            ``snapshot_download``.

    Returns:
        ``(model, tokenizer, config)``, where ``config`` is the parsed
        ``config.json`` dict.

    Raises:
        FileNotFoundError: if the model directory has no ``*.safetensors`` files.
    """
    # If the path exists, load the model from it;
    # otherwise download it from the hf_repo (snapshot_download caches it).
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
            )
        )

    with open(model_path / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    quantization = config.get("quantization", None)
    model_args = ModelArgs.from_dict(config)

    # glob.glob skips dot-files (e.g. macOS "._*" metadata files). Sorting
    # gives a deterministic shard load order.
    weight_files = sorted(glob.glob(str(model_path / "*.safetensors")))
    if not weight_files:
        raise FileNotFoundError("No safetensors found in {}".format(model_path))

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf).items())

    model = Model(model_args)
    if quantization is not None:
        # Quantize Linear layers to match the stored checkpoint. Layers whose
        # weight has exactly 8 output rows are left unquantized.
        # NOTE(review): presumably a guard for small gate-style layers; confirm.
        nn.QuantizedLinear.quantize_module(
            model,
            **quantization,
            linear_class_predicate=lambda m: isinstance(m, nn.Linear)
            and m.weight.shape[0] != 8,
        )

    model.load_weights(list(weights.items()))

    # Force evaluation so the parameters are materialized before returning.
    mx.eval(model.parameters())
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer, config
 | |
| 
 | |
| 
 | |
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
    """Yield sampled next tokens indefinitely, reusing the model's KV cache.

    Args:
        prompt: Token ids; a batch axis is added with ``[None]`` before each
            model call (presumably a 1-D array; confirm with callers).
        model: The language model, called as ``model(tokens, cache=cache)``.
        temp: Sampling temperature. ``0`` means greedy argmax; otherwise
            logits are multiplied by ``1 / temp`` and sampled categorically.

    Yields:
        The token array chosen at each step. The caller decides when to stop.
    """
    tokens = prompt
    cache = None
    while True:
        logits, cache = model(tokens[None], cache=cache)
        # Only the distribution for the last position matters.
        last_logits = logits[:, -1, :]
        if temp == 0:
            tokens = mx.argmax(last_logits, axis=-1)
        else:
            tokens = mx.random.categorical(last_logits * (1 / temp))
        yield tokens
 | 
