chore(mlx-lm): add missing model_type for starcoder2 (#522)

This commit is contained in:
Anchen 2024-03-04 01:07:45 +11:00 committed by GitHub
parent 3655bfc3bd
commit 1e3daea3bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -167,6 +167,7 @@ class Starcoder2Model(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Starcoder2Model(args)
# This is for 15B starcoder2 since it doesn't tie word embeddings
if not args.tie_word_embeddings: