mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Made mypy compatible
This commit is contained in:
@@ -23,7 +23,7 @@ class ModelArgs:
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
moe: dict = None
|
||||
moe: Optional[dict] = None
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -91,6 +91,9 @@ class FeedForward(nn.Module):
|
||||
class MOEFeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
if args.moe is None:
|
||||
raise ValueError("args.moe must not be None for MOEFeedForward")
|
||||
|
||||
self.num_experts = args.moe["num_experts"]
|
||||
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
|
||||
|
||||
Reference in New Issue
Block a user