From f2619f507c7dcde70410cc2cbb1d4715476d79ee Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Mon, 6 Jan 2025 10:58:43 -0500 Subject: [PATCH] Add support for fewshot and apply chat template lm_eval functionality (#1180) * Add support for multiturn fewshot examples and chat templates Added two new arguments to the evaluation script: `--fewshot-as-multiturn` and `--apply-chat-template` which correspond to lm_eval options of similar names and are very often used to ensure apples-to-apples comparisons of lm_evaluation results * Add HF overrides for methods needed by added options * don't add duplicate bos --------- Co-authored-by: Awni Hannun --- .circleci/config.yml | 2 +- llms/mlx_lm/evaluate.py | 59 +++++++++++++++++++++++++++++------------ llms/setup.py | 4 +-- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index cecd2d57..8367281e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,7 +32,7 @@ jobs: pip install --upgrade pip pip install unittest-xml-reporting cd llms/ - pip install -e ".[testing]" + pip install -e ".[test]" - run: name: Run Python tests command: | diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index bf7bf4d4..ca5e83bb 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -77,15 +77,19 @@ class MLXLM(LM): path_or_hf_repo: str, batch_size: int = 16, max_tokens: Optional[int] = None, + use_chat_template: Optional[bool] = None, ) -> None: super().__init__() self._batch_size = batch_size - self._model, self._tokenizer = load(path_or_hf_repo) - self._max_tokens = max_tokens or self._tokenizer.model_max_length + self._model, self.tokenizer = load(path_or_hf_repo) + self._max_tokens = max_tokens or self.tokenizer.model_max_length + self.use_chat_template = use_chat_template or ( + self.tokenizer.chat_template is not None + ) def _score_fn(self, inputs, tokenize=True, step_size=32): if tokenize: - inputs = self._tokenizer.encode(inputs) + inputs = self._tokenize(inputs) inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) inputs = mx.array(inputs) inputs, targets = inputs[..., :-1], inputs[..., 1:] @@ -149,7 +153,12 @@ class MLXLM(LM): return results def _tokenize(self, texts): - return [tuple(self._tokenizer.encode(t)) for t in texts] + return [ + tuple( + self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template) + ) + for t in texts + ] def loglikelihood(self, requests) -> list[tuple[float, bool]]: """Compute log-likelihood of generating a continuation from a context. @@ -221,6 +230,9 @@ class MLXLM(LM): ) return [(r[0], r[1] == r[2]) for r in results] + tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name + apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template + def loglikelihood_rolling(self, requests) -> list[float]: """Compute full log-likelihood of a string, with no truncation, for perplexity computation - We will use the full max context length of the model. @@ -283,21 +295,14 @@ class MLXLM(LM): completions = [] for context, until in tqdm(zip(contexts, untils), total=len(contexts)): - if self._tokenizer.chat_template is not None: - messages = [{"role": "user", "content": context}] - context = self._tokenizer.apply_chat_template( - messages, add_generation_prompt=True - ) - else: - context = self._tokenizer.encode(context) - + context = self._tokenize(context) max_tokens = min( self._max_tokens, - self._tokenizer.model_max_length - len(context), + self.tokenizer.model_max_length - len(context), ) text = "" for response in stream_generate( - self._model, self._tokenizer, prompt=context, max_tokens=max_tokens + self._model, self.tokenizer, prompt=context, max_tokens=max_tokens ): text += response.text if any(u in text for u in until): @@ -332,6 +337,21 @@ def main(): type=float, ) parser.add_argument("--seed", type=int, default=123, help="Random seed.") + parser.add_argument( + "--fewshot-as-multiturn", + action="store_true", + help="Whether to provide the fewshot examples as a multiturn " + "conversation or a single user turn.", + default=False, + ) + parser.add_argument( + "--apply-chat-template", + action=argparse.BooleanOptionalAction, + help="Specifies whether to apply a chat template to the prompt. If " + "the model has a chat template, this defaults to `True`, " + "otherwise `False`.", + default=None, + ) args = parser.parse_args() output_dir = Path(args.output_dir) @@ -342,18 +362,23 @@ def main(): mx.random.seed(args.seed) - lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens) - + lm = MLXLM( + args.model, + batch_size=args.batch_size, + max_tokens=args.max_tokens, + use_chat_template=args.apply_chat_template, + ) results = lm_eval.simple_evaluate( model=lm, tasks=args.tasks, + fewshot_as_multiturn=args.fewshot_as_multiturn, + apply_chat_template=lm.use_chat_template, num_fewshot=args.num_shots, limit=args.limit, random_seed=args.seed, numpy_random_seed=args.seed, torch_random_seed=args.seed, fewshot_random_seed=args.seed, - apply_chat_template=True, ) model_name = args.model.replace("/", "_") diff --git a/llms/setup.py b/llms/setup.py index b88dcd33..e6fddbae 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -27,8 +27,8 @@ setup( packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"], python_requires=">=3.8", extras_require={ - "testing": ["datasets"], - "evaluation": ["lm-eval"], + "test": ["datasets"], + "evaluate": ["lm-eval", "tqdm"], }, entry_points={ "console_scripts": [