mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00

Add support for multiturn fewshot examples and chat templates

Added two new arguments to the evaluation script, `--fewshot-as-multiturn` and `--apply-chat-template`, which correspond to lm_eval options of similar names and are very often used to ensure apples-to-apples comparisons of lm_eval results.

* Add HF overrides for methods needed by the added options
* Don't add a duplicate bos

Co-authored-by: Awni Hannun <awni@apple.com>

393 lines · 15 KiB · Python
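For example, with this change an instruction-tuned model can be evaluated with chat formatting and multiturn few-shot prompts along these lines (illustrative invocation; the entry-point name and model are placeholders, not part of this change):

    mlx_lm.evaluate --model mlx-community/Some-Instruct-Model-4bit \
        --tasks arc_challenge \
        --num-shots 5 \
        --apply-chat-template \
        --fewshot-as-multiturn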
# Copyright © 2024 Apple Inc.

"""
Adapted from a PyTorch implementation by David Grangier
"""

import argparse
import json
import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional, Union

import lm_eval
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm

from .models.cache import make_prompt_cache
from .utils import load, stream_generate

PAD = 0


def _len_longest_common_prefix(a, b):
    l = 0
    for item_a, item_b in zip(a, b):
        if item_a != item_b:
            break
        l += 1
    return l


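# Example for _len_longest_common_prefix (illustrative):
# _len_longest_common_prefix((1, 2, 3), (1, 2, 9, 9)) == 2, i.e. the number of
# leading tokens the two sequences share.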
def _rstrip_until(s, untils):
    """Limit a string <s> to the first occurrence of any substring in untils."""
    l = len(s)
    f = [s.find(u) for u in untils]
    f = [l if x < 0 else x for x in f]
    return s[: min(f)]


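# Example for _rstrip_until (illustrative):
# _rstrip_until("foo bar\n\nbaz", ["\n\n"]) returns "foo bar", i.e. the text is
# cut at the earliest stop sequence found.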
def _pad_inputs(
    inputs,
    maxlen,
    genlen=0,
    pad_left=False,
    pad_multiple=32,
    truncate=False,
):
    # pad the prompts to the left with at least genlen tokens.
    actual_maxlen = max(len(p) for p in inputs) + genlen
    if actual_maxlen > maxlen:
        if not truncate:
            raise ValueError("Inputs are too long.")
        else:  # drop beginning
            actual_maxlen = maxlen
            inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
    if pad_multiple > 0:
        maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
        maxlen *= pad_multiple
    assert PAD == 0
    lr = np.array((1, 0) if pad_left else (0, 1))
    return np.stack(
        [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
        axis=0,
    )


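# Shape example for _pad_inputs (illustrative): with PAD == 0 and the default
# pad_multiple=32, _pad_inputs([[1, 2, 3], [4, 5]], maxlen=2048) returns an
# int32 array of shape (2, 32), each prompt right-padded with zeros.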
@register_model("mlxlm")
class MLXLM(LM):
    def __init__(
        self,
        path_or_hf_repo: str,
        batch_size: int = 16,
        max_tokens: Optional[int] = None,
        use_chat_template: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self._batch_size = batch_size
        self._model, self.tokenizer = load(path_or_hf_repo)
        self._max_tokens = max_tokens or self.tokenizer.model_max_length
        if use_chat_template is None:
            # default to using the chat template when the tokenizer provides one
            use_chat_template = self.tokenizer.chat_template is not None
        self.use_chat_template = use_chat_template

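    # Illustrative programmatic use (the model path below is a placeholder):
    #
    #   lm = MLXLM("mlx-community/Some-Model-4bit", batch_size=8)
    #   results = lm_eval.simple_evaluate(model=lm, tasks=["arc_easy"])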
    def _score_fn(self, inputs, tokenize=True, step_size=32):
        if tokenize:
            inputs = self._tokenize(inputs)
        inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
        inputs = mx.array(inputs)
        inputs, targets = inputs[..., :-1], inputs[..., 1:]

        cache = make_prompt_cache(self._model)

        mask = targets != PAD

        scores, is_greedy = [], []
        for i in range(0, inputs.shape[1], step_size):
            logits = self._model(inputs[:, i : i + step_size], cache=cache)

            log_probs = nn.log_softmax(logits.astype(mx.float32))
            score = mx.take_along_axis(
                log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
            )[..., 0]
            ig = mask[:, i : i + step_size] * (
                targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
            )

            mx.eval(score, ig)
            mx.metal.clear_cache()

            is_greedy.append(ig)
            scores.append(score)

        scores = mx.concatenate(scores, axis=1)
        is_greedy = mx.concatenate(is_greedy, axis=1)

        return scores, mask.sum(axis=-1), is_greedy

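    # Shape note for _score_fn above (illustrative): for a padded batch of B
    # prompts with padded length L, `scores` and `is_greedy` have shape
    # (B, L - 1), and the second return value is the per-row count of non-PAD
    # target tokens.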
    def _loglikelihood(self, texts, score_spans=None, tokenize=True):
        # sort by length to get batches with little padding.
        sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
        sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
        sorted_spans = None
        if score_spans is not None:
            sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]

        results = []
        for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
            batch = sorted_inputs[i : i + self._batch_size]
            scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
            for j in range(len(batch)):
                if sorted_spans is None:  # full sequence score
                    mask = mx.arange(scores[j].shape[-1]) < length[j]
                    score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
                    ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
                else:  # subsequence score
                    start, end = sorted_spans[i + j]
                    score = scores[j][start:end].astype(mx.float32).sum()
                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
                    length = end - start

                results.append((score.item(), ig.item(), length))

        # reorder the outputs
        inv_sort = np.argsort(sorted_indices)
        results = [results[inv_sort[i]] for i in range(len(results))]

        return results

    def _tokenize(self, texts):
        return [
            tuple(
                self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
            )
            for t in texts
        ]

    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.
        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy`:
                Whether `continuation` would be generated by greedy sampling from `context`.
        """
        logging.info("Estimating loglikelihood for %d pairs." % len(requests))

        # tokenize prefix and prefix + completion for all requests.
        tokenized = self._tokenize(
            [t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]]
        )

        # max length (prefix + completion) and longest common prefix per question.
        length_stats = {}
        for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
            max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
            length_stats[prefix] = (
                max(max_completed_l, len(completed)),
                min(min_prefix_l, _len_longest_common_prefix(prefix, completed)),
            )

        # truncate requests for completed sequences longer than model context.
        shortened = []
        completion_spans = []
        long_completions = 0
        for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
            max_completed_l, prefix_l = length_stats[prefix]
            # compute truncation length
            truncation = max(0, max_completed_l - self._max_tokens - 1)
            prefix_l = prefix_l - truncation
            if prefix_l <= 0:
                # completion too long, prefix is eliminated for some requests.
                long_completions += 1
                truncation = max(0, len(completed) - self._max_tokens - 1)
                prefix_l = 1
            # truncate the completed sequence
            completed = completed[truncation:]
            shortened.append(completed)
            # scores do not include the initial bos, subtract 1 from the span bounds
            completion_spans.append((prefix_l - 1, len(completed) - 1))

        if long_completions > 0:
            logging.info(
                f"Prefix eliminated for {long_completions} requests with "
                + "completion longer than context."
            )

        # model scoring, returns num_requests x (logp, is_greedy, length).
        results = self._loglikelihood(
            shortened,
            score_spans=completion_spans,
            tokenize=False,
        )
        return [(r[0], r[1] == r[2]) for r in results]

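    # Worked example for loglikelihood() above (illustrative): for the request
    # (context="The capital of France is", continuation=" Paris"), the context
    # and context + continuation are tokenized, and only the score span covering
    # the continuation tokens is summed, giving log p(" Paris" | context) plus a
    # flag for whether greedy decoding would have produced it.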
    # Reuse these HFLM helpers so that lm-eval's --apply-chat-template and
    # --fewshot-as-multiturn code paths work with this wrapper.
    tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
    apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template

    def loglikelihood_rolling(self, requests) -> list[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
        which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
        multiple chunks, the last input will still have a full-sized context.
        Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:
                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3
                INPUT:    3   4   5   6
                PRED:     4   5   6   7
                INPUT:    5   6   7   8
                PRED:             8   9
        Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens
        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context,).
            string: str
                String for which we are computing overall loglikelihood
        :return: list[tuple[float]]
            A list of tuples (logprob,)
            logprob: float
                The log probability of `context` conditioned on the EOT token.
        """
        logging.info(
            "Estimating loglikelihood rolling for %d sequences." % len(requests)
        )
        inputs = [req.args[0] for req in requests]
        return [t[0] for t in self._loglikelihood(inputs)]

    def generate_until(self, requests) -> list[str]:
        """Generate greedily until a stopping sequence
        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, until).
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list[str]
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        logging.info("Generating continuation for %d sequences." % len(requests))
        contexts, options = zip(*[req.args for req in requests])
        # contrary to the doc, the second element of the tuple contains
        # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
        keys = list(options[0].keys())
        assert "until" in keys
        untils = [x["until"] for x in options]
        completions = []

        for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
            context = self._tokenize([context])[0]
            max_tokens = min(
                self._max_tokens,
                self.tokenizer.model_max_length - len(context),
            )
            text = ""
            for response in stream_generate(
                self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
            ):
                text += response.text
                if any(u in text for u in until):
                    text = _rstrip_until(text, until)
                    completions.append(text)
                    break
            else:
                completions.append(text)
        return completions


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate an MLX model using lm-evaluation-harness."
    )
    parser.add_argument("--model", help="Model to evaluate", required=True)
    parser.add_argument("--tasks", nargs="+", required=True)
    parser.add_argument(
        "--output-dir", default=".", help="Output directory for result files."
    )
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--num-shots", type=int, default=0, help="Number of shots")
    parser.add_argument(
        "--max-tokens",
        type=int,
        help="Maximum number of tokens to generate. Defaults to the model's max context length.",
    )
    parser.add_argument(
        "--limit",
        default=1.0,
        help="Limit the number of examples per task.",
        type=float,
    )
    parser.add_argument("--seed", type=int, default=123, help="Random seed.")
    parser.add_argument(
        "--fewshot-as-multiturn",
        action="store_true",
        help="Whether to provide the fewshot examples as a multiturn "
        "conversation or a single user turn.",
        default=False,
    )
    parser.add_argument(
        "--apply-chat-template",
        action=argparse.BooleanOptionalAction,
        help="Specifies whether to apply a chat template to the prompt. If "
        "the model has a chat template, this defaults to `True`, "
        "otherwise `False`.",
        default=None,
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Silence tokenizer warnings
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    mx.random.seed(args.seed)

    lm = MLXLM(
        args.model,
        batch_size=args.batch_size,
        max_tokens=args.max_tokens,
        use_chat_template=args.apply_chat_template,
    )
    results = lm_eval.simple_evaluate(
        model=lm,
        tasks=args.tasks,
        fewshot_as_multiturn=args.fewshot_as_multiturn,
        apply_chat_template=lm.use_chat_template,
        num_fewshot=args.num_shots,
        limit=args.limit,
        random_seed=args.seed,
        numpy_random_seed=args.seed,
        torch_random_seed=args.seed,
        fewshot_random_seed=args.seed,
    )

    model_name = args.model.replace("/", "_")
    task_names = "_".join(args.tasks)
    ver = version("lm_eval")
    filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json"
    output_path = output_dir / filename
    output_path.write_text(json.dumps(results["results"], indent=4))
    print("Results:")
    for result in results["results"].values():
        print(json.dumps(result, indent=4))
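

if __name__ == "__main__":
    # Assumed entry point so the evaluation can also be run as a standalone script.
    main()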