Made mypy compatible

This commit is contained in:
paramthakkar123 2025-04-04 07:34:43 +05:30
parent c52cc748f8
commit d7cab9d5f5
4 changed files with 13 additions and 9 deletions

View File

@ -40,7 +40,7 @@ def generate(
if len(tokens) == 0: if len(tokens) == 0:
print("No tokens generated for this prompt") print("No tokens generated for this prompt")
return return
prompt_tps = prompt.size / prompt_time prompt_tps = len(prompt) / prompt_time
gen_tps = (len(tokens) - 1) / gen_time gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec") print(f"Generation: {gen_tps:.3f} tokens-per-sec")

View File

@ -19,10 +19,10 @@ class ModelArgs:
rms_norm_eps: float rms_norm_eps: float
vocab_size: int vocab_size: int
context_length: int context_length: int
num_key_value_heads: int = None num_key_value_heads: Optional[int] = None
rope_theta: float = 10000 rope_theta: float = 10000
rope_traditional: bool = False rope_traditional: bool = False
model_type: str = None model_type: Optional[str] = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self): def __post_init__(self):
@ -54,7 +54,7 @@ class Attention(nn.Module):
dim = args.hidden_size dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
self.repeats = n_heads // n_kv_heads self.repeats = n_heads // n_kv_heads
@ -66,7 +66,7 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = ( rope_scale = (
1 / args.rope_scaling["factor"] 1 / float(args.rope_scaling["factor"])
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1 else 1
) )
@ -254,7 +254,7 @@ def translate_weight_names(name):
return name return name
def load(gguf_file: str, repo: str = None): def load(gguf_file: str, repo: Optional[str] = None):
# If the gguf_file exists, try to load model from it. # If the gguf_file exists, try to load model from it.
# Otherwise try to download and cache from the HF repo # Otherwise try to download and cache from the HF repo
if not Path(gguf_file).exists(): if not Path(gguf_file).exists():

View File

@ -7,7 +7,7 @@ import glob
import json import json
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Dict
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import torch import torch
@ -149,7 +149,8 @@ def quantize(weights, config, args):
def make_shards(weights: dict, max_file_size_gibibyte: int = 15): def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
max_file_size_bytes = max_file_size_gibibyte << 30 max_file_size_bytes = max_file_size_gibibyte << 30
shards = [] shards = []
shard, shard_size = {}, 0 shard : Dict[str, mx.array] = {}
shard_size = 0
for k, v in weights.items(): for k, v in weights.items():
if shard_size + v.nbytes > max_file_size_bytes: if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard) shards.append(shard)

View File

@ -23,7 +23,7 @@ class ModelArgs:
n_kv_heads: int n_kv_heads: int
norm_eps: float norm_eps: float
vocab_size: int vocab_size: int
moe: dict = None moe: Optional[dict] = None
class Attention(nn.Module): class Attention(nn.Module):
@ -92,6 +92,9 @@ class MOEFeedForward(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
if args.moe is None:
raise ValueError("args.moe must not be None for MOEFeedForward")
self.num_experts = args.moe["num_experts"] self.num_experts = args.moe["num_experts"]
self.num_experts_per_tok = args.moe["num_experts_per_tok"] self.num_experts_per_tok = args.moe["num_experts_per_tok"]
self.experts = [FeedForward(args) for _ in range(self.num_experts)] self.experts = [FeedForward(args) for _ in range(self.num_experts)]