qlora
Authored by Awni Hannun on 2024-01-04 21:05:59 -08:00; committed by GitHub
parent 4fa659acbd
commit 37b41cec60
8 changed files with 137 additions and 51 deletions


@@ -17,12 +17,10 @@ from sentencepiece import SentencePieceProcessor
 def build_parser():
-    parser = argparse.ArgumentParser(
-        description="LoRA finetuning with Llama or Mistral"
-    )
+    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
-        required=True,
+        default="mlx_model",
         help="A path to the model files containing the tokenizer, weights, config.",
     )
     # Generation args
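A minimal sketch of what this hunk changes for the CLI, assuming no other flag in build_parser is required: --model can now be omitted and falls back to the "mlx_model" folder (the path string comes from the diff; everything else here is illustrative).

import argparse

# Rebuild only the changed argument; the real parser in the script adds many
# more options (generation, LoRA, and training arguments).
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
    "--model",
    default="mlx_model",
    help="A path to the model files containing the tokenizer, weights, config.",
)

args = parser.parse_args([])  # no --model passed on the command line
print(args.model)             # -> mlx_model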
@@ -332,18 +330,22 @@ def generate(model, prompt, tokenizer, args):
         print(s, flush=True)
 
 
-def load_model(folder: str, dtype=mx.float16):
+def load_model(folder: str):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
-    with open(model_path / "params.json", "r") as f:
+    with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
         if config.get("vocab_size", -1) < 0:
             config["vocab_size"] = tokenizer.vocab_size
+        quantization = config.pop("quantization", None)
         model_args = ModelArgs(**config)
+    model = Model(model_args)
+    if quantization is not None:
+        quantization["linear_class_predicate"] = lambda m: isinstance(
+            m, nn.Linear
+        ) and (m.weight.shape[0] != model_args.vocab_size)
+        nn.QuantizedLinear.quantize_module(model, **quantization)
     weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
-    weights = tree_map(lambda p: p.astype(dtype), weights)
-    model = Model(model_args)
     model.update(weights)
     return model, tokenizer
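To make the predicate concrete, here is a small self-contained sketch outside the script (the toy module, layer names, and sizes are made up): every nn.Linear is swapped for a QuantizedLinear except the projection whose first weight dimension equals the vocabulary size, which load_model above keeps in full precision. It assumes the mlx version used by this commit, where nn.QuantizedLinear.quantize_module accepts group_size, bits, and linear_class_predicate, as the diff itself implies.

import mlx.nn as nn

vocab_size = 32000  # stand-in for model_args.vocab_size

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(256, 256)            # weight shape (256, 256): gets quantized
        self.lm_head = nn.Linear(256, vocab_size)  # weight shape (vocab_size, 256): left as-is

model = Toy()
nn.QuantizedLinear.quantize_module(
    model,
    group_size=64,  # example values; the script reads them from config["quantization"]
    bits=4,
    linear_class_predicate=lambda m: isinstance(m, nn.Linear)
    and m.weight.shape[0] != vocab_size,
)
print(type(model.proj).__name__, type(model.lm_head).__name__)  # QuantizedLinear Linear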
@@ -374,7 +376,7 @@ if __name__ == "__main__":
     # Resume training the given adapters.
     if args.resume_adapter_file is not None:
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
-        model.load_weights(args.resume_adapter_file)
+        model.load_weights(args.resume_adapter_file, strict=False)
 
     if args.train:
         print("Training")
@@ -387,7 +389,12 @@ if __name__ == "__main__":
         mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))
 
     # Load the LoRA adapter weights which we assume should exist by this point
-    model.load_weights(args.adapter_file)
+    if not Path(args.adapter_file).is_file():
+        raise ValueError(
+            f"Adapter file {args.adapter_file} missing. "
+            "Use --train to learn and save the adapters.npz."
+        )
+    model.load_weights(args.adapter_file, strict=False)
 
     if args.test:
         print("Testing")