mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Configurable LR schedulers (#604)
* Initial config handler and test * Added means to run from CLI * Update lora config loading and tests * Constrain scheduler config (warmup and minimum LR) for each kind * Update reference to moved schedule_config module * Minor fix * Fix typos * Moved build_schedule and tests * nits in schedule config * flake * fix path --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from sys import exit
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -183,11 +183,7 @@ class PlamoModel(nn.Module):
|
||||
mask = mask.astype(self.embed_tokens.weight.dtype)
|
||||
|
||||
if cache is None:
|
||||
past_key_values_length = 0
|
||||
cache = [None for _ in range(len(self.layers.layers))]
|
||||
else:
|
||||
if cache[0] is not None:
|
||||
past_key_values_length = cache[0][0].shape[2]
|
||||
|
||||
for e, layer in enumerate(self.layers.layers):
|
||||
h, c = layer(h, mask, cache[e])
|
||||
|
Reference in New Issue
Block a user