diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index aeebfc96..27a53af9 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -167,6 +167,7 @@ class Starcoder2Model(nn.Module): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.model_type = args.model_type self.model = Starcoder2Model(args) # This is for 15B starcoder2 since it doesn't tie word embeddings if not args.tie_word_embeddings: