From eda597bdef485fa29d4f139cec0cda5c5f62748b Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 9 Feb 2025 19:37:11 -0800
Subject: [PATCH] simplify

---
 llms/mlx_lm/tuner/datasets.py | 82 +++++++++++------------------------
 1 file changed, 25 insertions(+), 57 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 44c78450..a6f3bd29 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -1,5 +1,6 @@
 import itertools
 import json
+import types
 from pathlib import Path
 from typing import Any, Dict, List, Optional
@@ -115,22 +116,26 @@ class ConcatenatedDataset:
 def create_dataset(
     data,
     tokenizer: PreTrainedTokenizer,
-    config: Dict,
+    config,
 ):
     mask_prompt = getattr(config, "mask_prompt", False)
     prompt_feature = getattr(config, "prompt_feature", "prompt")
+    text_feature = getattr(config, "text_feature", "text")
     completion_feature = getattr(config, "completion_feature", "completion")
+    chat_feature = getattr(config, "chat_feature", "messages")
     sample = data[0]
-    if "messages" in sample:
-        return ChatDataset(data, tokenizer, mask_prompt=mask_prompt)
-    elif prompt_feature in sample and completion_feature in sample:
+    if prompt_feature in sample and completion_feature in sample:
         return CompletionsDataset(
             data, tokenizer, prompt_feature, completion_feature, mask_prompt
         )
-    elif "text" in sample:
+    elif chat_feature in sample:
+        return ChatDataset(
+            data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
+        )
+    elif text_feature in sample:
         if mask_prompt:
             raise ValueError("Prompt masking not supported for text dataset.")
-        return Dataset(data, tokenizer)
+        return Dataset(data, tokenizer, text_key=text_feature)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -141,7 +146,7 @@ def create_dataset(
 def load_local_dataset(
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
-    config: Dict,
+    config,
 ):
     def load_subset(path):
         if not path.exists():
@@ -158,7 +163,7 @@ def load_local_dataset(
 def load_hf_dataset(
     data_id: str,
     tokenizer: PreTrainedTokenizer,
-    config: Dict,
+    config,
 ):
     from datasets import exceptions, load_dataset
 
@@ -185,39 +190,13 @@ def load_hf_dataset(
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
 
-    mask_prompt = getattr(args, "mask_prompt", False)
-
-    def create_hf_dataset(
-        dataset_name,
-        text_feature,
-        prompt_feature,
-        completion_feature,
-        chat_feature,
-        split,
-        config,
-    ):
+    def create_hf_dataset(dataset_name, config, split, hf_config):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
-            **config,
+            **hf_config,
         )
-        if prompt_feature and completion_feature:
-            return CompletionsDataset(
-                ds, tokenizer, prompt_feature, completion_feature, mask_prompt
-            )
-        elif chat_feature:
-            return ChatDataset(
-                ds, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
-            )
-        elif text_feature:
-            if mask_prompt:
-                raise ValueError("Prompt masking not supported for text dataset.")
-            return Dataset(ds, tokenizer, text_key=text_feature)
-        else:
-            raise ValueError(
-                "Specify either a prompt and completion feature, a chat feature,"
-                " or a text feature for the Hugging Face dataset."
-            )
+        return create_dataset(ds, tokenizer, config)
 
     dataset_collection = args.hf_dataset
     if isinstance(dataset_collection, dict):
@@ -227,31 +206,23 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     for ds in dataset_collection:
         ds_name = ds["name"]
         print(f"Loading Hugging Face dataset {ds_name}.")
-        text_f = ds.get("text_feature", None)
-        prompt_f = ds.get("prompt_feature", None)
-        completion_f = ds.get("completion_feature", None)
-        chat_f = ds.get("chat_feature", None)
-        ds_config = ds.get("config", {})
+        ds["mask_prompt"] = getattr(args, "mask_prompt", False)
+        config = types.SimpleNamespace(**ds)
+        hf_config = ds.get("config", {})
         if args.train:
             train_split = ds.get("train_split", "train[:80%]")
             valid_split = ds.get("valid_split", "train[-10%:]")
             train = create_hf_dataset(
                 ds_name,
-                text_f,
-                prompt_f,
-                completion_f,
-                chat_f,
+                config,
                 train_split,
-                ds_config,
+                hf_config,
             )
             valid = create_hf_dataset(
                 ds_name,
-                text_f,
-                prompt_f,
-                completion_f,
-                chat_f,
+                config,
                 valid_split,
-                ds_config,
+                hf_config,
             )
         else:
             train, valid = [], []
@@ -260,12 +231,9 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
             test_split = ds.get("test_split")
             test = create_hf_dataset(
                 ds_name,
-                text_f,
-                prompt_f,
-                completion_f,
-                chat_f,
+                config,
                 test_split,
-                ds_config,
+                hf_config,
             )
         else:
             test = []
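Note (not part of the patch): after this change, Hugging Face datasets no longer have
their own feature-dispatch logic; each entry of args.hf_dataset is wrapped in a
types.SimpleNamespace and routed through the shared create_dataset() helper, which
reads feature names via getattr with defaults. The sketch below mirrors that routing
so the new control flow can be checked in isolation; the dataset name, record
contents, and the route() helper are illustrative assumptions, not part of the patch.

    import types

    def route(sample, config):
        # Same feature lookups and defaults as the patched create_dataset().
        prompt_feature = getattr(config, "prompt_feature", "prompt")
        completion_feature = getattr(config, "completion_feature", "completion")
        chat_feature = getattr(config, "chat_feature", "messages")
        text_feature = getattr(config, "text_feature", "text")
        if prompt_feature in sample and completion_feature in sample:
            return "completions"
        elif chat_feature in sample:
            return "chat"
        elif text_feature in sample:
            return "text"
        raise ValueError("Unsupported data format")

    # One entry from args.hf_dataset, as the patched loop sees it. The name
    # and chat_feature values are made up for illustration.
    ds = {"name": "my-org/my-chat-data", "chat_feature": "conversation"}
    ds["mask_prompt"] = False             # copied from args, as in the patch
    config = types.SimpleNamespace(**ds)  # dict keys become attribute lookups

    sample = {"conversation": [{"role": "user", "content": "hi"}]}
    print(route(sample, config))          # -> chat

Wrapping the per-dataset dict in a SimpleNamespace is what lets the getattr-based
create_dataset() serve both local JSON configs and Hugging Face dataset entries
through a single code path.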