Minor fixes

Fixed import path, iteration calculation, and creation of configuration namespace from YAML.
Chime Ogbuji
2024-01-05 12:22:00 -05:00
parent 0dbc78ca65
commit 119cadbd09
2 changed files with 10 additions and 8 deletions


@@ -9,7 +9,7 @@ Assumes records are of the form:
 And renders the input using the MISTRAL prompt format
 """
-from .supervised_lora import TrainingRecordHandler, main
+from supervised_lora import TrainingRecordHandler, main
 MISTRAL_INSTRUCTION_PROMPT = "[INST] {prompt_input}[/INST] "
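
The import fix above swaps the relative form for the absolute one. A plausible reading, which is an assumption rather than anything stated in the commit, is that this module is also executed directly as a script, where a relative import has no parent package to resolve against. A minimal illustration of that behaviour (the __package__ guard is illustrative, not part of the repository code):

# Illustrative only: a relative import resolves against the importing module's
# package, so a file executed directly as a script has no parent package and the
# relative form raises ImportError.  The absolute form only needs
# supervised_lora to be importable from sys.path.
if __package__:
    from .supervised_lora import TrainingRecordHandler, main   # imported as part of a package
else:
    from supervised_lora import TrainingRecordHandler, main    # run as a plain script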


@@ -102,13 +102,15 @@ def build_parser():
     )
     parser.add_argument('filename', help="The YAML confguration file")
     with open(parser.parse_args().filename, 'r') as file:
-        config = SimpleNamespace(**yaml.safe_load(file))
-        if 'model' not in config.__dict__:
-            raise argparse.ArgumentError('model', 'Missing required argument')
+        config = yaml.safe_load(file)
+        param_dict = config["parameters"]
+        if 'model' not in param_dict:
+            raise SyntaxError('Missing required "model" parameter')
         for key, default in CONFIG_DEFAULTS.items():
-            if key not in config.__dict__:
-                setattr(config, key, default)
-        return config.parameters
+            if key not in param_dict:
+                param_dict[key] = default
+        print("Configuration:", param_dict)
+        return SimpleNamespace(**param_dict)
 class Tokenizer:
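
As a standalone illustration of the corrected flow above: the namespace is now built from the nested "parameters" mapping after required-key validation and defaulting, rather than from the raw top-level YAML document. The defaults, file layout, and helper name below are illustrative assumptions, not values taken from the repository:

from types import SimpleNamespace
import yaml

CONFIG_DEFAULTS = {"batch_size": 4, "epochs": -1, "iters": -1}  # hypothetical defaults

def load_config(filename):
    """Hedged sketch of build_parser's post-fix configuration handling."""
    with open(filename, "r") as file:
        param_dict = yaml.safe_load(file)["parameters"]
    if "model" not in param_dict:
        raise SyntaxError('Missing required "model" parameter')
    for key, default in CONFIG_DEFAULTS.items():
        param_dict.setdefault(key, default)
    # Attribute-style access downstream: config.model, config.batch_size, ...
    return SimpleNamespace(**param_dict)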
@@ -266,7 +268,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args, handler):
     #The number of steps for 1 epoch
     epoch_num_steps = (len(train_set) + args.batch_size - 1) // args.batch_size
-    if args.epochs != -1:
+    if args.epochs == -1:
         num_iterations = epoch_num_steps if args.iters == -1 else args.iters
     else:
         num_iterations = epoch_num_steps * args.epochs
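
The inverted comparison above meant an explicit epoch count previously fell through to the iteration branch. A small sketch of the corrected logic, pulled out as a pure function so the three cases are easy to check; the function name and the sample numbers are illustrative, not from the repository:

def iteration_count(num_records, batch_size, epochs=-1, iters=-1):
    # Steps needed to see every record once (ceiling division).
    epoch_num_steps = (num_records + batch_size - 1) // batch_size
    if epochs == -1:
        # No epoch budget: honour an explicit iteration count, else default to one epoch.
        return epoch_num_steps if iters == -1 else iters
    # An explicit epoch count overrides the iteration setting.
    return epoch_num_steps * epochs

assert iteration_count(100, 8) == 13             # one epoch by default
assert iteration_count(100, 8, iters=50) == 50   # explicit iteration budget
assert iteration_count(100, 8, epochs=3) == 39   # epochs take precedence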