mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Minor fixes
Fixed import path, iteration calculation, and creation of configuration namespace from YAML.
This commit is contained in:
@@ -9,7 +9,7 @@ Assumes records are of the form:
|
||||
And renders the input using the MISTRAL prompt format
|
||||
|
||||
"""
|
||||
from .supervised_lora import TrainingRecordHandler, main
|
||||
from supervised_lora import TrainingRecordHandler, main
|
||||
|
||||
MISTRAL_INSTRUCTION_PROMPT = "[INST] {prompt_input}[/INST] "
|
||||
|
||||
|
@@ -102,13 +102,15 @@ def build_parser():
|
||||
)
|
||||
parser.add_argument('filename', help="The YAML confguration file")
|
||||
with open(parser.parse_args().filename, 'r') as file:
|
||||
config = SimpleNamespace(**yaml.safe_load(file))
|
||||
if 'model' not in config.__dict__:
|
||||
raise argparse.ArgumentError('model', 'Missing required argument')
|
||||
config = yaml.safe_load(file)
|
||||
param_dict = config["parameters"]
|
||||
if 'model' not in param_dict:
|
||||
raise SyntaxError('Missing required "model" parameter')
|
||||
for key, default in CONFIG_DEFAULTS.items():
|
||||
if key not in config.__dict__:
|
||||
setattr(config, key, default)
|
||||
return config.parameters
|
||||
if key not in param_dict:
|
||||
param_dict[key] = default
|
||||
print("Configuration:", param_dict)
|
||||
return SimpleNamespace(**param_dict)
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
@@ -266,7 +268,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args, handler):
|
||||
#The number of steps for 1 epoch
|
||||
epoch_num_steps = (len(train_set) + args.batch_size - 1) // args.batch_size
|
||||
|
||||
if args.epochs != -1:
|
||||
if args.epochs == -1:
|
||||
num_iterations = epoch_num_steps if args.iters == -1 else args.iters
|
||||
else:
|
||||
num_iterations = epoch_num_steps * args.epochs
|
||||
|
Reference in New Issue
Block a user