make models/phi3.py and models/phi3small.py compatible with mypy (#833)

Yi Wang
2024-06-12 06:53:55 -07:00
committed by GitHub
parent fda41545a6
commit 6da07fb1b0
2 changed files with 15 additions and 11 deletions


@@ -1,3 +1,4 @@
+import math
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
@@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
@dataclass
@@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs):
    num_attention_heads: int
    layer_norm_epsilon: float
    vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
    mup_attn_multiplier: float = 1.0
    mup_use_scaling: bool = True
    mup_embedding_multiplier: float = 10.0
    mup_width_multiplier: float = 8.0
    rope_embedding_base: float = 1000000
    rope_position_scale: float = 1.0
-    blocksparse_block_size: int = (64,)
+    blocksparse_block_size: Tuple[int] = (64,)
    blocksparse_num_local_blocks: int = 16
    blocksparse_vert_stride: int = 8
@@ -58,6 +59,7 @@ class Attention(nn.Module):
        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.n_q_per_kv = n_heads // n_kv_heads
@@ -157,7 +159,7 @@ class Attention(nn.Module):
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
    ) -> mx.array:
        B, L, D = x.shape
@@ -226,7 +228,7 @@ class TransformerBlock(nn.Module):
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
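
For context, the diff applies standard typing patterns so the modules pass mypy: a field that defaults to None is annotated `Optional[int]`, a tuple-valued default gets a tuple type instead of `int`, and an `assert ... is not None` narrows the Optional before it is used in integer division. Below is a minimal, self-contained sketch of these patterns; it is not part of the commit, and the `Args` class and `heads_per_kv` helper are hypothetical stand-ins for `ModelArgs` and the `n_q_per_kv` computation in `Attention`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Args:
    # Hypothetical stand-in for ModelArgs, reduced to the fields the diff touches.
    num_attention_heads: int = 32
    # A field that defaults to None must be Optional[...]; with implicit
    # Optional disabled, mypy rejects `num_key_value_heads: int = None`.
    num_key_value_heads: Optional[int] = None
    # The default value is a one-element tuple, so the annotation is a
    # tuple type rather than `int`.
    blocksparse_block_size: Tuple[int] = (64,)


def heads_per_kv(args: Args) -> int:
    # `assert ... is not None` narrows Optional[int] to int, so the integer
    # division below type-checks (mirrors the assert added in Attention).
    assert args.num_key_value_heads is not None
    return args.num_attention_heads // args.num_key_value_heads


print(heads_per_kv(Args(num_key_value_heads=8)))  # 4
```

With these annotations mypy accepts the module; without the assert it would flag the division, since one operand could still be None.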