Made mypy compatible

This commit is contained in:
paramthakkar123
2025-04-04 07:34:43 +05:30
parent c52cc748f8
commit d7cab9d5f5
4 changed files with 13 additions and 9 deletions

View File

@@ -23,7 +23,7 @@ class ModelArgs:
n_kv_heads: int
norm_eps: float
vocab_size: int
moe: dict = None
moe: Optional[dict] = None
class Attention(nn.Module):
@@ -91,6 +91,9 @@ class FeedForward(nn.Module):
class MOEFeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
if args.moe is None:
raise ValueError("args.moe must not be None for MOEFeedForward")
self.num_experts = args.moe["num_experts"]
self.num_experts_per_tok = args.moe["num_experts_per_tok"]