* add olmo2

* add olmo2
This commit is contained in:
Awni Hannun
2024-12-02 11:42:58 -08:00
committed by GitHub
parent cefe793ae0
commit 8801beb66f
3 changed files with 333 additions and 0 deletions

View File

@@ -792,6 +792,26 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_olmo2(self):
from mlx_lm.models import olmo2
args = olmo2.ModelArgs(
model_type="olmo2",
hidden_size=128,
attention_bias=False,
intermediate_size=256,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
rms_norm_eps=1e-4,
rope_theta=1000,
vocab_size=1000,
)
model = olmo2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
if __name__ == "__main__":
unittest.main()