Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)

Made mypy compatible

This commit is contained in:
parent c52cc748f8
commit d7cab9d5f5
@@ -40,7 +40,7 @@ def generate(
     if len(tokens) == 0:
         print("No tokens generated for this prompt")
         return
-    prompt_tps = prompt.size / prompt_time
+    prompt_tps = len(prompt) / prompt_time
     gen_tps = (len(tokens) - 1) / gen_time
     print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
     print(f"Generation: {gen_tps:.3f} tokens-per-sec")

@@ -19,10 +19,10 @@ class ModelArgs:
     rms_norm_eps: float
     vocab_size: int
     context_length: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None

     def __post_init__(self):

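The hunk above is the core of the mypy fix: a default of `None` is only valid when the annotation is `Optional[...]`. A minimal sketch of the same pattern, with illustrative names rather than the repo's actual ModelArgs:

    # Minimal sketch, assuming a dataclass field defaulted to None (names are illustrative).
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class Args:
        num_attention_heads: int
        num_key_value_heads: Optional[int] = None  # `int = None` is rejected by mypy

        def __post_init__(self):
            # Resolve the Optional before downstream code relies on an int.
            if self.num_key_value_heads is None:
                self.num_key_value_heads = self.num_attention_heads
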
@@ -54,7 +54,7 @@ class Attention(nn.Module):

         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads

         self.repeats = n_heads // n_kv_heads

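Once the field is `Optional[int]`, the consumer has to collapse `None` to a concrete value, which is what the `or n_heads` fallback above does. A hedged sketch of the same idea in isolation (the function name is illustrative):

    # Sketch: collapsing an Optional[int] to a plain int at the point of use.
    from typing import Optional

    def resolve_kv_heads(num_key_value_heads: Optional[int], n_heads: int) -> int:
        # `or` maps None (and also a literal 0) to the fallback, so mypy sees an int.
        return num_key_value_heads or n_heads
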
@@ -66,7 +66,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )

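Because `rope_scaling` is declared as `Optional[Dict[str, Union[float, str]]]`, indexing it yields `Union[float, str]`, which mypy will not accept in a division; the `float(...)` cast above narrows it. A small illustration with made-up values:

    # Sketch: dividing by a Dict[str, Union[float, str]] value needs a cast or narrowing.
    from typing import Dict, Union

    rope_scaling: Dict[str, Union[float, str]] = {"type": "linear", "factor": 2.0}

    rope_scale = 1 / float(rope_scaling["factor"])  # float() accepts both float and str
    # An isinstance() check would be an equivalent way to narrow the union.
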
@@ -254,7 +254,7 @@ def translate_weight_names(name):
     return name


-def load(gguf_file: str, repo: str = None):
+def load(gguf_file: str, repo: Optional[str] = None):
     # If the gguf_file exists, try to load model from it.
     # Otherwise try to download and cache from the HF repo
     if not Path(gguf_file).exists():

@@ -7,7 +7,7 @@ import glob
 import json
 import shutil
 from pathlib import Path
+from typing import Dict
 import mlx.core as mx
 import mlx.nn as nn
 import torch

@@ -149,7 +149,8 @@ def quantize(weights, config, args):
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard : Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)

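A variable annotation cannot be attached to one target of a tuple assignment, and an empty `{}` gives mypy no value type to infer, so the commit splits the assignment and annotates the dict explicitly (hence the `from typing import Dict` added above). A minimal sketch of the pattern, assuming mlx is installed; the names mirror the diff but are otherwise illustrative:

    # Sketch: annotate an empty container explicitly so mypy knows its value type.
    from typing import Dict

    import mlx.core as mx

    shard: Dict[str, mx.array] = {}  # `shard, shard_size = {}, 0` cannot carry this annotation
    shard_size = 0
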
@@ -23,7 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
-    moe: dict = None
+    moe: Optional[dict] = None


 class Attention(nn.Module):

@@ -92,6 +92,9 @@ class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()

+        if args.moe is None:
+            raise ValueError("args.moe must not be None for MOEFeedForward")
+
         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
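
With `moe` now typed `Optional[dict]`, subscripting it directly would be an Optional member-access error under mypy; raising when it is None narrows the type for the rest of `__init__`. A standalone sketch of the narrowing (the function and keys are illustrative):

    # Sketch: an early raise narrows Optional[dict] to dict for the code that follows.
    from typing import Optional

    def num_experts(moe: Optional[dict]) -> int:
        if moe is None:
            raise ValueError("moe must not be None")
        return int(moe["num_experts"])  # mypy now treats `moe` as dict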