From f787c085858697326f79a5f3288ce2635eecb171 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Thu, 23 Jan 2025 06:36:31 -0800
Subject: [PATCH] comments

---
 llms/mlx_lm/evaluate.py | 92 ++++++++++++++++++++++-------------------
 1 file changed, 49 insertions(+), 43 deletions(-)

diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py
index 56dce27c..dc39f181 100644
--- a/llms/mlx_lm/evaluate.py
+++ b/llms/mlx_lm/evaluate.py
@@ -10,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional
 
 import lm_eval
 import mlx.core as mx
@@ -43,13 +43,13 @@ def _rstrip_until(s, untils):
 
 
 def _pad_inputs(inputs):
-    lengths = mx.array([len(x) for x in inputs])
+    lengths = np.array([len(x) for x in inputs])
     maxlen = lengths.max()
-    padded = mx.stack(
-        [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
+    padded = np.stack(
+        [np.pad(x, (0, maxlen - len(x))) for x in inputs],
         axis=0,
     )
-    return padded, lengths
+    return mx.array(padded), mx.array(lengths)
 
 
 @register_model("mlxlm")
@@ -65,26 +65,24 @@ class MLXLM(LM):
         self._batch_size = batch_size
         self._model, self.tokenizer = load(path_or_hf_repo)
         self._max_tokens = max_tokens or self.tokenizer.model_max_length
-        self.use_chat_template = use_chat_template or (
+        self.use_chat_template = use_chat_template and (
             self.tokenizer.chat_template is not None
         )
 
-    def _score_fn(self, inputs, step_size=64):
+    def _score_fn(self, inputs, step_size: int = 64):
         inputs, lengths = _pad_inputs(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]
 
         cache = make_prompt_cache(self._model)
 
-        # TODO: come up with a better way to get the dtype
-        dtype = self._model.model.embed_tokens(inputs).dtype
-
         scores, is_greedy = [], []
         for i in range(0, inputs.shape[1], step_size):
             inp = inputs[:, i : i + step_size]
             T = inp.shape[1]
 
             offset = cache[0].offset
-            mask = create_causal_mask(T, offset, lengths=lengths).astype(dtype)
+            mask = create_causal_mask(T, offset, lengths=lengths)
+            mask = mask == 0
 
             logits = self._model(inp, cache=cache, mask=mask)
             log_probs = nn.log_softmax(logits.astype(mx.float32))
@@ -107,24 +105,29 @@ class MLXLM(LM):
         return scores, lengths, is_greedy
 
     def _loglikelihood(self, texts, score_spans=None):
-        results = []
+        all_scores = mx.zeros(len(texts))
+        all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
         for i in tqdm(range(0, len(texts), self._batch_size)):
             batch = texts[i : i + self._batch_size]
-            scores, length, is_greedy = self._score_fn(batch)
-            for j in range(len(batch)):
-                if score_spans is None:  # full sequence score
-                    l = length[j].item()
-                    score = scores[j][:l].astype(mx.float32).sum()
-                    ig = is_greedy[j][:l].astype(mx.int32).sum()
-                else:  # subsequence score
-                    start, end = score_spans[i + j]
-                    score = scores[j][start:end].astype(mx.float32).sum()
-                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
-                    length = end - start
+            scores, lengths, is_greedy = self._score_fn(batch)
 
-                results.append((score.item(), ig.item(), length))
+            ind = np.arange(scores.shape[-1])
+            if score_spans is not None:
+                spans = score_spans[i : i + self._batch_size]
+                lengths = [end - start for start, end in spans]
+                masks = mx.array(
+                    np.array([(ind >= start) & (ind < end) for start, end in spans])
+                )
+            else:
+                masks = ind[None] < lengths[:, None]
 
-        return results
+            scores = (masks * scores).sum(axis=-1)
+            is_greedy = (masks * is_greedy).sum(axis=-1)
+
+            all_scores[i : i + self._batch_size] = scores
+            all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
+
+        return all_scores, all_is_greedy
 
     def _tokenize(self, texts):
         return [
@@ -203,23 +206,20 @@ class MLXLM(LM):
             shortened = [shortened[i] for i in sorted_indices]
             completion_spans = [completion_spans[i] for i in sorted_indices]
 
-        group = mx.distributed.init() if mx.distributed.is_available() else None
-        if group is not None:
-            # split strided so we have approximately the same lengths on each node
-            shortened = shortened[group.rank() :: group.size()]
-            completion_spans = completion_spans[group.rank() :: group.size()]
+        group = mx.distributed.init()
+
+        # split strided so we have approximately the same lengths on each node
+        shortened = shortened[group.rank() :: group.size()]
+        completion_spans = completion_spans[group.rank() :: group.size()]
 
         # model scoring, returns num_requests x (logp, is_greedy, length).
-        results = self._loglikelihood(
+        scores, is_greedy = self._loglikelihood(
             shortened,
             score_spans=completion_spans,
         )
 
-        scores = mx.array([r[0] for r in results])
-        is_greedy = mx.array([r[1] == r[2] for r in results])
-
         # all gather the results across groups
-        if group is not None:
+        if group.size() > 1:
             per_group = int(np.ceil(num_results / group.size()))
             scores = mx.pad(scores, ((0, per_group - len(scores)),))
             is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
@@ -237,7 +237,15 @@ class MLXLM(LM):
         return results
 
     tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
-    apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
+
+    def apply_chat_template(
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
+    ) -> str:
+        if len(chat_history) == 0:
+            return ""
+        return lm_eval.models.huggingface.HFLM.apply_chat_template(
+            self, chat_history, add_generation_prompt
+        )
 
     def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
@@ -275,7 +283,8 @@ class MLXLM(LM):
             "Estimating loglikelihood rolling for %d sequences." % len(requests)
         )
         inputs = self._tokenize([req.args[0] for req in requests])
-        return [t[0] for t in self._loglikelihood(inputs)]
+        scores, _ = self._loglikelihood(inputs)
+        return scores.tolist()
 
     def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
@@ -338,7 +347,7 @@ def main():
     )
     parser.add_argument(
         "--limit",
-        default=1.0,
+        default=None,
         help="Limit the number of examples per task.",
         type=float,
     )
@@ -352,11 +361,8 @@ def main():
     )
     parser.add_argument(
         "--apply-chat-template",
-        action=argparse.BooleanOptionalAction,
-        help="Specifies whether to apply a chat template to the prompt. If "
-        "the model has a chat template, this defaults to `True`, "
-        "otherwise `False`.",
-        default=None,
+        action="store_true",
+        help="Specifies whether to apply a chat template to the prompt.",
     )
 
     args = parser.parse_args()
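
Note on the vectorized scoring (illustrative sketch, not part of the patch): the rewritten _loglikelihood replaces the per-sequence Python loop with boolean masks that select either the full sequence or a [start, end) completion span before summing. A minimal NumPy-only sketch of that masking idea, with made-up scores and spans (all values below are hypothetical, not taken from evaluate.py):

    import numpy as np

    # Toy per-token log-probs for two padded sequences (hypothetical values).
    scores = np.array([[-0.5, -1.0, -2.0, 0.0],
                       [-0.1, -0.3, -0.7, -0.2]])
    # Score only these [start, end) spans, e.g. the completion part of each request.
    spans = [(1, 3), (0, 4)]

    ind = np.arange(scores.shape[-1])
    masks = np.array([(ind >= start) & (ind < end) for start, end in spans])
    span_scores = (masks * scores).sum(axis=-1)  # per-sequence sum of masked log-probs
    print(span_scores)  # [-3.  -1.3]

The same masks are applied to the per-token is_greedy flags, so a completion only counts as greedy when every position inside its span matched the argmax.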