From c2fcb6738b82d83d32dcf532d611c170cf18143e Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 4 Feb 2025 11:02:00 +0100 Subject: [PATCH] fix testing --- llms/mlx_lm/lora.py | 6 ++--- llms/mlx_lm/tuner/datasets.py | 42 +++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 2791db2f..4e0948d3 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -312,11 +312,10 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set else: reference_model = model - test_loss, test_rewards = evaluate_dpo( + test_loss, _, _, _ = evaluate_dpo( model=model, ref_model=reference_model, dataset=test_set, - tokenizer=tokenizer, batch_size=args.batch_size, num_batches=args.test_batches, max_seq_length=args.max_seq_length, @@ -324,7 +323,8 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set delta=args.delta, loss_type=args.dpo_loss_type, ) - print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}") + print(f"Test loss {test_loss:.3f}") + else: test_loss = evaluate( model=model, diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index e3d1b8fb..5a24a694 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -119,6 +119,7 @@ class CompletionsDataset: def create_dataset( + args, data, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -127,24 +128,30 @@ def create_dataset( prompt_feature = prompt_feature or "prompt" completion_feature = completion_feature or "completion" sample = data[0] - - # Add DPO dataset support - if "chosen" in sample and "rejected" in sample: - return DPODataset(data, tokenizer) - elif "messages" in sample: - return ChatDataset(data, tokenizer) - elif prompt_feature in sample and completion_feature in sample: - return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) - elif "text" in sample: - return Dataset(data, tokenizer) + + if args.training_mode == "normal": + if "messages" in sample: + return ChatDataset(data, tokenizer) + elif prompt_feature in sample and completion_feature in sample: + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) + elif "text" in sample: + return Dataset(data, tokenizer) + else: + raise ValueError( + "Unsupported data format, check the supported formats here:\n" + "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data." + ) + elif args.training_mode == "dpo": + if "chosen" in sample and "rejected" in sample: + return DPODataset(data, tokenizer) else: raise ValueError( - "Unsupported data format, check the supported formats here:\n" - "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data." + "Unsupported training mode, check the supported training modes and their formats here:\n" + "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#training-modes." ) - def load_local_dataset( + args, data_path: Path, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -155,7 +162,7 @@ def load_local_dataset( return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer, prompt_feature, completion_feature) + return create_dataset(args, data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] @@ -163,6 +170,7 @@ def load_local_dataset( def load_hf_dataset( + args, data_id: str, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -178,7 +186,7 @@ def load_hf_dataset( train, valid, test = [ ( create_dataset( - dataset[n], tokenizer, prompt_feature, completion_feature + args, dataset[n], tokenizer, prompt_feature, completion_feature ) if n in dataset.keys() else [] @@ -243,12 +251,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): train, valid, test = load_local_dataset( - data_path, tokenizer, prompt_feature, completion_feature + args, data_path, tokenizer, prompt_feature, completion_feature ) else: print(f"Loading Hugging Face dataset {args.data}.") train, valid, test = load_hf_dataset( - args.data, tokenizer, prompt_feature, completion_feature + args, args.data, tokenizer, prompt_feature, completion_feature ) if args.train and len(train) == 0: