Configurable LR schedulers (#604)

* Add initial config handler and test

* Add means to run from the CLI

* Update LoRA config loading and tests

* Constrain scheduler config (warmup and minimum LR) for each kind

* Update reference to the moved schedule_config module

* Minor fix

* Fix typos

* Move build_schedule and its tests

* Fix nits in schedule config

* Fix flake8 warnings

* Fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Chime Ogbuji
Date: 2024-03-29 16:41:10 -04:00
Committed by: GitHub
Parent commit: b80adbcc3e
Commit: f6283ef7ce
7 changed files with 93 additions and 12 deletions
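
The bullets above outline the feature: a learning-rate schedule is built from a config section, with warmup and minimum-LR settings constrained per schedule kind. Below is a minimal sketch of how such a builder can be put together from the schedulers in mlx.optimizers. It is an illustration rather than the code from this commit, and the config keys (name, arguments, warmup, warmup_init) are assumed names:

import mlx.optimizers as optim


def build_schedule(schedule_config: dict):
    """Build an LR schedule, optionally preceded by a linear warmup.

    Sketch only: the config keys are assumptions, not this commit's schema.
    """
    # Resolve the schedule by name, e.g. "cosine_decay" -> optim.cosine_decay.
    schedule_fn = getattr(optim, schedule_config["name"])
    arguments = schedule_config["arguments"]
    initial_lr = arguments[0]  # by convention, the first argument is the peak LR
    main_schedule = schedule_fn(*arguments)

    warmup_steps = schedule_config.get("warmup", 0)
    if warmup_steps > 0:
        # Ramp linearly from warmup_init to the peak LR, then hand off
        # to the main schedule at the warmup boundary.
        warmup_init = schedule_config.get("warmup_init", 0.0)
        warmup_fn = optim.linear_schedule(warmup_init, initial_lr, warmup_steps)
        return optim.join_schedules([warmup_fn, main_schedule], [warmup_steps + 1])
    return main_schedule


# Usage: a schedule is a step -> LR callable, which mlx optimizers accept
# directly as their learning rate.
schedule = build_schedule(
    {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 1000]}
)
optimizer = optim.Adam(learning_rate=schedule)

Under this shape, the "Constrain scheduler config (warmup and minimum LR) for each kind" bullet would correspond to validating those fields per schedule name before calling the builder.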


@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -183,11 +183,7 @@ class PlamoModel(nn.Module):
             mask = mask.astype(self.embed_tokens.weight.dtype)
 
         if cache is None:
-            past_key_values_length = 0
             cache = [None for _ in range(len(self.layers.layers))]
-        else:
-            if cache[0] is not None:
-                past_key_values_length = cache[0][0].shape[2]
 
         for e, layer in enumerate(self.layers.layers):
             h, c = layer(h, mask, cache[e])