version bump + some fixes (#792)

This commit is contained in:
Awni Hannun 2024-05-21 20:09:35 -07:00 committed by GitHub
parent 9f671228cd
commit 9fc6efbd90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 9 additions and 5 deletions

View File

@@ -80,7 +80,7 @@ def build_parser():
parser.add_argument( parser.add_argument(
"--lora-layers", "--lora-layers",
type=int, type=int,
help="Number of layers to fine-tune", help="Number of layers to fine-tune. Default is 16, use -1 for all.",
) )
parser.add_argument("--batch-size", type=int, help="Minibatch size.") parser.add_argument("--batch-size", type=int, help="Minibatch size.")
parser.add_argument("--iters", type=int, help="Iterations to train for.") parser.add_argument("--iters", type=int, help="Iterations to train for.")
@@ -143,7 +143,7 @@ def build_parser():
help="Use gradient checkpointing to reduce memory use.", help="Use gradient checkpointing to reduce memory use.",
default=None, default=None,
) )
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
parser.add_argument( parser.add_argument(
"--use-dora", action="store_true", default=None, help="Use DoRA to finetune." "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
) )
@@ -268,7 +268,7 @@ def main():
config = yaml.load(file, yaml_loader) config = yaml.load(file, yaml_loader)
# Prefer parameters from command-line arguments # Prefer parameters from command-line arguments
for k, v in config.items(): for k, v in config.items():
if args.get(k, None) is not None: if args.get(k, None) is None:
args[k] = v args[k] = v
# Update defaults for unspecified parameters # Update defaults for unspecified parameters

View File

@@ -1,4 +1,4 @@
mlx>=0.11 mlx>=0.13.1
numpy numpy
transformers>=4.39.3 transformers>=4.39.3
protobuf protobuf

View File

@@ -54,6 +54,10 @@ def linear_to_lora_layers(
""" """
num_layers = len(model.layers) num_layers = len(model.layers)
if num_lora_layers < 0:
num_lora_layers = num_layers
if num_lora_layers > num_layers: if num_lora_layers > num_layers:
raise ValueError( raise ValueError(
f"Requested {num_lora_layers} LoRA layers " f"Requested {num_lora_layers} LoRA layers "

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.13.1" __version__ = "0.14.0"