mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
Add support for Llama-3.1 (#907)
* add DynamicNTK scaling rope
* remove unused var
* fix rope base
* llama3.1 fixes
* TODO for rope eval
* vectorise llama3 base freq calculation
* remove the arbitrary 2.0 rope_scale default case
* fix slow llama3.1 generation by evaluating the stateless part of DynamicNTKScalingRoPE in init
* nits + format
* use mx.pi
* fix tests and add a test for 3.1
---------
Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -449,6 +449,33 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_llama3_1(self):
    """Smoke-test a tiny Llama model configured with llama3-style rope scaling.

    Builds a small-configuration ``llama.Model`` whose ``rope_scaling`` dict
    uses ``rope_type="llama3"`` (the Llama-3.1 frequency-scaling scheme) and
    runs it through the shared ``model_test_runner`` harness.
    """
    from mlx_lm.models import llama

    # Rope-scaling parameters in the shape emitted for Llama-3.1 checkpoints.
    rope_scaling_cfg = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    }

    # Deliberately tiny dimensions so the test stays fast.
    config = dict(
        model_type="llama",
        hidden_size=1024,
        num_hidden_layers=4,
        intermediate_size=2048,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        vocab_size=10_000,
        max_position_embeddings=128,
        mlp_bias=False,
        num_key_value_heads=2,
        rope_scaling=rope_scaling_cfg,
    )

    args = llama.ModelArgs(**config)
    model = llama.Model(args)
    self.model_test_runner(
        model, args.model_type, args.vocab_size, args.num_hidden_layers
    )
|
||||
|
||||
|
||||
# Run the full test suite when this file is executed directly
# (e.g. `python test_models.py`).
if __name__ == "__main__":
    unittest.main()
Reference in New Issue
Block a user