diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py
new file mode 100644
index 00000000..423d5823
--- /dev/null
+++ b/llms/mlx_lm/evaluate.py
@@ -0,0 +1,355 @@
+# Adapted from a PyTorch implementation by David Grangier
+
+import argparse
+import json
+import logging
+import os
+from importlib.metadata import version
+from pathlib import Path
+from typing import Optional
+
+import lm_eval
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model
+from tqdm import tqdm
+
+from .models.cache import make_prompt_cache
+from .utils import load, stream_generate
+
+PAD = 0
+
+
+def _len_longest_common_prefix(a, b):
+    l = 0
+    for item_a, item_b in zip(a, b):
+        if item_a != item_b:
+            break
+        l += 1
+    return l
+
+
+def _rstrip_until(s, untils):
+    """Limit a string to the first occurrence of any substring in untils."""
+    l = len(s)
+    f = [s.find(u) for u in untils]
+    f = [l if x < 0 else x for x in f]
+    return s[: min(f)]
+
+
+def _pad_inputs(
+    inputs,
+    maxlen,
+    genlen=0,
+    pad_left=False,
+    pad_multiple=32,
+    truncate=False,
+):
+    # pad the prompts to the left with at least genlen tokens.
+    actual_maxlen = max(len(p) for p in inputs) + genlen
+    if actual_maxlen > maxlen:
+        if not truncate:
+            raise ValueError("Inputs are too long.")
+        else:  # drop the beginning
+            actual_maxlen = maxlen
+            inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
+    if pad_multiple > 0:
+        maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
+        maxlen *= pad_multiple
+    assert PAD == 0
+    lr = np.array((1, 0) if pad_left else (0, 1))
+    return np.stack(
+        [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
+        axis=0,
+    )
+
+
+@register_model("mlxlm")
+class MLXLM(LM):
+    def __init__(
+        self,
+        path_or_hf_repo: str,
+        batch_size: int = 16,
+        max_tokens: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self._batch_size = batch_size
+        self._model, self._tokenizer = load(path_or_hf_repo)
+        self._max_tokens = max_tokens or self._tokenizer.model_max_length
+
+    def _score_fn(self, inputs, tokenize=True, step_size=32):
+        if tokenize:
+            inputs = self._tokenizer.encode(inputs)
+        inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
+        inputs = mx.array(inputs)
+        inputs, targets = inputs[..., :-1], inputs[..., 1:]
+
+        cache = make_prompt_cache(self._model)
+
+        mask = targets != PAD
+
+        scores, is_greedy = [], []
+        for i in range(0, inputs.shape[1], step_size):
+            logits = self._model(inputs[:, i : i + step_size], cache=cache)
+
+            log_probs = nn.log_softmax(logits.astype(mx.float32))
+            score = mx.take_along_axis(
+                log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
+            )[..., 0]
+            ig = mask[:, i : i + step_size] * (
+                targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
+            )
+
+            mx.eval(score, ig)
+            mx.metal.clear_cache()
+
+            is_greedy.append(ig)
+            scores.append(score)
+
+        scores = mx.concatenate(scores, axis=1)
+        is_greedy = mx.concatenate(is_greedy, axis=1)
+
+        return scores, mask.sum(axis=-1), is_greedy
+
+    def _loglikelihood(self, texts, score_spans=None, tokenize=True):
+        # sort by length to get batches with little padding.
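+        # Scores are computed in this sorted order and mapped back to the
+        # original request order at the end via np.argsort(sorted_indices).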
+        sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
+        sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
+        sorted_spans = None
+        if score_spans is not None:
+            sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
+
+        results = []
+        for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
+            batch = sorted_inputs[i : i + self._batch_size]
+            scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
+            for j in range(len(batch)):
+                if sorted_spans is None:  # full sequence score
+                    mask = mx.arange(scores[j].shape[-1]) < length[j]
+                    score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
+                    ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
+                else:  # subsequence score
+                    start, end = sorted_spans[i + j]
+                    score = scores[j][start:end].astype(mx.float32).sum()
+                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
+                    length = end - start
+
+                results.append((score.item(), ig.item(), length))
+
+        # reorder the outputs
+        inv_sort = np.argsort(sorted_indices)
+        results = [results[inv_sort[i]] for i in range(len(results))]
+
+        return results
+
+    def _tokenize(self, texts):
+        return [tuple(self._tokenizer.encode(t)) for t in texts]
+
+    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
+        """Compute log-likelihood of generating a continuation from a context.
+        Downstream tasks should attempt to use loglikelihood instead of other
+        LM calls whenever possible.
+        :param requests: list[Instance]
+            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
+            `context: str`
+                Context string. Implementations of LM must be able to handle an
+                empty context string.
+            `continuation: str`
+                The continuation over which log likelihood will be calculated. If
+                there is a word boundary, the space should be in the continuation.
+                For example, context="hello" continuation=" world" is correct.
+        :return: list[tuple[float, bool]]
+            A list of pairs (logprob, isgreedy)
+            `logprob: float`
+                The log probability of `continuation`.
+            `isgreedy`:
+                Whether `continuation` would be generated by greedy sampling from `context`.
+        """
+        logging.info("Estimating loglikelihood for %d pairs." % len(requests))
+
+        # tokenize prefix and prefix + completion for all requests.
+        tokenized = self._tokenize(
+            [t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]]
+        )
+
+        # max length (prefix + completion) and longest common prefix per question.
+        length_stats = {}
+        for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
+            max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
+            length_stats[prefix] = (
+                max(max_completed_l, len(completed)),
+                min(min_prefix_l, _len_longest_common_prefix(prefix, completed)),
+            )
+
+        # truncate requests for completed sequences longer than model context.
+        shortened = []
+        completion_spans = []
+        long_completions = 0
+        for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
+            max_completed_l, prefix_l = length_stats[prefix]
+            # compute truncation length
+            truncation = max(0, max_completed_l - self._max_tokens - 1)
+            prefix_l = prefix_l - truncation
+            if prefix_l <= 0:
+                # completion too long, prefix is eliminated for some requests.
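+                # Fall back to truncating the completed sequence itself and
+                # score every remaining token after the first one.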
+                long_completions += 1
+                truncation = max(0, len(completed) - self._max_tokens - 1)
+                prefix_l = 1
+            # truncate the completed sequence
+            completed = completed[truncation:]
+            shortened.append(completed)
+            # scores do not include the initial bos, subtract 1 from the span bounds
+            completion_spans.append((prefix_l - 1, len(completed) - 1))
+
+        if long_completions > 0:
+            logging.info(
+                f"Prefix eliminated for {long_completions} requests with "
+                + "completion longer than context."
+            )
+
+        # model scoring, returns num_requests x (logp, is_greedy, length).
+        results = self._loglikelihood(
+            shortened,
+            score_spans=completion_spans,
+            tokenize=False,
+        )
+        return [(r[0], r[1] == r[2]) for r in results]
+
+    def loglikelihood_rolling(self, requests) -> list[float]:
+        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
+        - We will use the full max context length of the model.
+        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
+          the max context length.
+        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
+          which may simply concatenate multiple documents together.
+        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
+          multiple chunks, the last input will still have a full-sized context.
+        Example:
+            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
+            Prefix: EOT
+            Max context length: 4
+            Resulting input/prediction pairs:
+                INPUT:  EOT   0   1   2
+                PRED:         0   1   2   3
+                INPUT:    3   4   5   6
+                PRED:     4   5   6   7
+                INPUT:    5   6   7   8
+                PRED:                 8   9
+        Observe that:
+            1. Each token is predicted exactly once
+            2. For the last pair, we provide the full context, but only score the last two tokens
+        :param requests: list[Instance]
+            A list of Instance objects with property `args` which returns a tuple (context,).
+            string: str
+                String for which we are computing overall loglikelihood
+        :return: list[tuple[float]]
+            A list of tuples (logprob,)
+            logprob: float
+                The log probability of `context` conditioned on the EOT token.
+        """
+        logging.info(
+            "Estimating loglikelihood rolling for %d sequences." % len(requests)
+        )
+        inputs = [req.args[0] for req in requests]
+        return [t[0] for t in self._loglikelihood(inputs)]
+
+    def generate_until(self, requests) -> list[str]:
+        """Generate greedily until a stopping sequence
+        :param requests: list[Instance]
+            A list of Instance objects with property `args` which returns a tuple (context, until).
+            context: str
+                Context string
+            until: [str]
+                The string sequences to generate until. These string sequences
+                may each span across multiple tokens, or may be part of one token.
+        :return: list[str]
+            A list of continuation strings
+            continuation: str
+                The generated continuation.
+        """
+        logging.info("Generating continuation for %d sequences." % len(requests))
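+        # Requests are decoded one at a time: apply the chat template when one
+        # is available, stream tokens, and stop at the first "until" sequence.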
+        contexts, options = zip(*[req.args for req in requests])
+        # contrary to the docstring, the second element of the tuple is a dict like
+        # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
+        keys = list(options[0].keys())
+        assert "until" in keys
+        untils = [x["until"] for x in options]
+        completions = []
+        for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
+            if (
+                hasattr(self._tokenizer, "apply_chat_template")
+                and self._tokenizer.chat_template is not None
+            ):
+                messages = [{"role": "user", "content": context}]
+                context = self._tokenizer.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
+                )
+
+            max_tokens = min(
+                self._max_tokens,
+                self._tokenizer.model_max_length - len(self._tokenizer.encode(context)),
+            )
+            text = ""
+            for response in stream_generate(
+                self._model, self._tokenizer, prompt=context, max_tokens=max_tokens
+            ):
+                text += response.text
+                if any(u in text for u in until):
+                    text = _rstrip_until(text, until)
+                    completions.append(text)
+                    break
+            else:
+                completions.append(text)
+        return completions
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        "Evaluate an MLX model using lm-evaluation-harness."
+    )
+    parser.add_argument("--model", help="Model to evaluate", required=True)
+    parser.add_argument("--tasks", nargs="+", required=True)
+    parser.add_argument(
+        "--output-dir", default=".", help="Output directory for result files."
+    )
+    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
+    parser.add_argument("--num-shots", type=int, default=0, help="Number of shots")
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        help="Maximum number of tokens to generate. Defaults to the model's max context length.",
+    )
+    parser.add_argument("--seed", type=int, default=123, help="Random seed.")
+    args = parser.parse_args()
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Silence tokenizer warnings
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    mx.random.seed(args.seed)
+
+    lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens)
+
+    results = lm_eval.simple_evaluate(
+        model=lm,
+        tasks=args.tasks,
+        num_fewshot=args.num_shots,
+        random_seed=args.seed,
+        numpy_random_seed=args.seed,
+        torch_random_seed=args.seed,
+        fewshot_random_seed=args.seed,
+    )
+
+    model_name = args.model.replace("/", "_")
+    task_names = "_".join(args.tasks)
+    ver = version("lm_eval")
+    filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json"
+    output_path = output_dir / filename
+    output_path.write_text(json.dumps(results["results"], indent=4))
+    print("Results:")
+    for result in results["results"].values():
+        print(json.dumps(result, indent=4))
diff --git a/llms/setup.py b/llms/setup.py
index 1c696dc0..b88dcd33 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -28,12 +28,14 @@ setup(
     python_requires=">=3.8",
     extras_require={
         "testing": ["datasets"],
+        "evaluation": ["lm-eval"],
     },
     entry_points={
         "console_scripts": [
             "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
             "mlx_lm.chat = mlx_lm.chat:main",
             "mlx_lm.convert = mlx_lm.convert:main",
+            "mlx_lm.evaluate = mlx_lm.evaluate:main",
             "mlx_lm.fuse = mlx_lm.fuse:main",
             "mlx_lm.generate = mlx_lm.generate:main",
             "mlx_lm.lora = mlx_lm.lora:main",