# mlx-examples/llms/mlx_lm/evaluate.py
# Copyright © 2024 Apple Inc.
"""
Adapted from a PyTorch implementation by David Grangier
"""
import argparse
import json
import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional
import lm_eval
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
from .models.base import create_causal_mask
from .models.cache import make_prompt_cache
from .utils import load, stream_generate
def _len_longest_common_prefix(a, b):
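    """Return the length of the longest common prefix of two sequences."""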
l = 0
for item_a, item_b in zip(a, b):
if item_a != item_b:
break
l += 1
return l
def _rstrip_until(s, untils):
"""Limit a string <s> to the first occurrence of any substring in untils."""
l = len(s)
f = [s.find(u) for u in untils]
f = [l if x < 0 else x for x in f]
return s[: min(f)]
def _pad_inputs(inputs):
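    """Right-pad token sequences to a common length; return the padded batch
    and the original sequence lengths as mx.arrays."""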
lengths = np.array([len(x) for x in inputs])
maxlen = lengths.max()
padded = np.stack(
[np.pad(x, (0, maxlen - len(x))) for x in inputs],
axis=0,
)
return mx.array(padded), mx.array(lengths)
@register_model("mlxlm")
class MLXLM(LM):
def __init__(
self,
path_or_hf_repo: str,
batch_size: int = 16,
max_tokens: Optional[int] = None,
use_chat_template: Optional[bool] = None,
) -> None:
super().__init__()
self._batch_size = batch_size
self._model, self.tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self.tokenizer.model_max_length
self.use_chat_template = use_chat_template and (
self.tokenizer.chat_template is not None
)
def _score_fn(self, inputs, step_size: int = 64):
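        """Return per-token log-probabilities of the targets, the original
        sequence lengths, and per-token flags for whether each target matches
        greedy decoding. The batch is processed in chunks of `step_size`
        through a prompt cache to bound peak memory."""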
inputs, lengths = _pad_inputs(inputs)
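        # shift by one token: the logits at position t are scored against the token at t + 1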
inputs, targets = inputs[..., :-1], inputs[..., 1:]
cache = make_prompt_cache(self._model)
scores, is_greedy = [], []
for i in range(0, inputs.shape[1], step_size):
inp = inputs[:, i : i + step_size]
T = inp.shape[1]
offset = cache[0].offset
mask = create_causal_mask(T, offset, lengths=lengths)
mask = mask == 0
logits = self._model(inp, cache=cache, mask=mask)
log_probs = nn.log_softmax(logits.astype(mx.float32))
score = mx.take_along_axis(
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
)[..., 0]
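            # greedy flag: the target token matches the model's argmax, restricted to valid (non-padding) positions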
ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)
mx.eval(score, ig)
mx.metal.clear_cache()
is_greedy.append(ig)
scores.append(score)
scores = mx.concatenate(scores, axis=1)
is_greedy = mx.concatenate(is_greedy, axis=1)
return scores, lengths, is_greedy
def _loglikelihood(self, texts, score_spans=None):
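        """Score batches of tokenized sequences. When `score_spans` is given,
        only target positions inside each (start, end) span contribute;
        otherwise every position up to the sequence length does. Returns, per
        sequence, the summed log-probability and whether the scored span
        would be reproduced by greedy decoding."""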
all_scores = mx.zeros(len(texts))
all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
for i in tqdm(range(0, len(texts), self._batch_size)):
batch = texts[i : i + self._batch_size]
scores, lengths, is_greedy = self._score_fn(batch)
ind = np.arange(scores.shape[-1])
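            # build per-sequence masks over target positions: either the requested (start, end) spans or all non-padding positions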
if score_spans is not None:
spans = score_spans[i : i + self._batch_size]
lengths = [end - start for start, end in spans]
masks = mx.array(
np.array([(ind >= start) & (ind < end) for start, end in spans])
)
else:
masks = ind[None] < lengths[:, None]
scores = (masks * scores).sum(axis=-1)
is_greedy = (masks * is_greedy).sum(axis=-1)
all_scores[i : i + self._batch_size] = scores
all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
return all_scores, all_is_greedy
def _tokenize(self, texts):
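        """Encode each text to a tuple of token ids, adding special tokens only when no chat template is in use."""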
return [
tuple(
self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
)
for t in texts
]
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
:param requests: list[Instance]
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
`context: str`
Context string. Implementations of LM must be able to handle an
empty context string.
`continuation: str`
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
`logprob: float`
The log probability of `continuation`.
`isgreedy`:
Whether `continuation` would be generated by greedy sampling from `context`.
"""
logging.info("Estimating loglikelihood for %d pairs." % len(requests))
# tokenize prefix and prefix + completion for all requests.
tokenized = self._tokenize(
[t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]]
)
# max length (prefix + completion) and longest common prefix per question.
length_stats = {}
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
length_stats[prefix] = (
max(max_completed_l, len(completed)),
min(min_prefix_l, _len_longest_common_prefix(prefix, completed)),
)
# truncate requests for completed sequences longer than model context.
shortened = []
completion_spans = []
long_completions = 0
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, prefix_l = length_stats[prefix]
# compute truncation length
truncation = max(0, max_completed_l - self._max_tokens - 1)
prefix_l = prefix_l - truncation
if prefix_l <= 0:
# completion too long, prefix is eliminated for some requests.
long_completions += 1
truncation = max(0, len(completed) - self._max_tokens - 1)
prefix_l = 1
# truncate the completed sequence
completed = completed[truncation:]
shortened.append(completed)
            # scores do not include the initial bos, subtract 1 from the span bounds
completion_spans.append((prefix_l - 1, len(completed) - 1))
if long_completions > 0:
logging.info(
f"Prefix eliminated for {long_completions} requests with "
+ "completion longer than context."
)
num_results = len(shortened)
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
shortened = [shortened[i] for i in sorted_indices]
completion_spans = [completion_spans[i] for i in sorted_indices]
group = mx.distributed.init()
# split strided so we have approximately the same lengths on each node
shortened = shortened[group.rank() :: group.size()]
completion_spans = completion_spans[group.rank() :: group.size()]
# model scoring, returns num_requests x (logp, is_greedy, length).
scores, is_greedy = self._loglikelihood(
shortened,
score_spans=completion_spans,
)
# all gather the results across groups
if group.size() > 1:
per_group = int(np.ceil(num_results / group.size()))
scores = mx.pad(scores, ((0, per_group - len(scores)),))
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
scores = scores.T.reshape(-1)
is_greedy = is_greedy.T.reshape(-1)
scores = np.array(scores[:num_results])
is_greedy = np.array(is_greedy[:num_results])
results = [(score, ig) for score, ig in zip(scores, is_greedy)]
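        # undo the length-based sort so results line up with the original request order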
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
return results
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
if len(chat_history) == 0:
return ""
        return lm_eval.models.huggingface.HFLM.apply_chat_template(
            self, chat_history, add_generation_prompt
        )
def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
        multiple chunks, the last input will still have a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
            INPUT:  EOT   0   1   2
            PRED:     0   1   2   3
            INPUT:    3   4   5   6
            PRED:     4   5   6   7
            INPUT:    5   6   7   8
            PRED:             8   9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context,).
string: str
String for which we are computing overall loglikelihood
:return: list[tuple[float]]
A list of tuples (logprob,)
logprob: float
The log probability of `context` conditioned on the EOT token.
"""
logging.info(
"Estimating loglikelihood rolling for %d sequences." % len(requests)
)
inputs = self._tokenize([req.args[0] for req in requests])
scores, _ = self._loglikelihood(inputs)
return scores.tolist()
def generate_until(self, requests) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, until).
context: str
Context string
until: [str]
The string sequences to generate until. These string sequences
may each span across multiple tokens, or may be part of one token.
:return: list[str]
A list of strings continuation
continuation: str
The generated continuation.
"""
logging.info("Generating continuation for %d sequences." % len(requests))
contexts, options = zip(*[req.args for req in requests])
# contrary to the doc the second element of the tuple contains
# {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
keys = list(options[0].keys())
assert "until" in keys
untils = [x["until"] for x in options]
completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
            context = self._tokenize([context])[0]
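            # cap generation so the prompt plus completion fits in the model's context window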
max_tokens = min(
self._max_tokens,
self.tokenizer.model_max_length - len(context),
)
text = ""
for response in stream_generate(
self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
):
text += response.text
if any(u in text for u in until):
text = _rstrip_until(text, until)
completions.append(text)
break
else:
completions.append(text)
return completions
def main():
parser = argparse.ArgumentParser(
"Evaluate an MLX model using lm-evaluation-harness."
)
parser.add_argument("--model", help="Model to evaluate", required=True)
parser.add_argument("--tasks", nargs="+", required=True)
parser.add_argument(
"--output-dir", default=".", help="Output directory for result files."
)
parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
parser.add_argument("--num-shots", type=int, default=0, help="Number of shots")
parser.add_argument(
"--max-tokens",
type=int,
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
)
parser.add_argument(
"--limit",
default=None,
help="Limit the number of examples per task.",
type=float,
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument(
"--fewshot-as-multiturn",
action="store_true",
help="Whether to provide the fewshot examples as a multiturn "
"conversation or a single user turn.",
default=False,
)
parser.add_argument(
"--apply-chat-template",
action="store_true",
help="Specifies whether to apply a chat template to the prompt.",
)
args = parser.parse_args()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Silence tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
mx.random.seed(args.seed)
lm = MLXLM(
args.model,
batch_size=args.batch_size,
max_tokens=args.max_tokens,
use_chat_template=args.apply_chat_template,
)
results = lm_eval.simple_evaluate(
model=lm,
tasks=args.tasks,
fewshot_as_multiturn=args.fewshot_as_multiturn,
apply_chat_template=lm.use_chat_template,
num_fewshot=args.num_shots,
limit=args.limit,
random_seed=args.seed,
numpy_random_seed=args.seed,
torch_random_seed=args.seed,
fewshot_random_seed=args.seed,
)
model_name = args.model.replace("/", "_")
task_names = "_".join(args.tasks)
ver = version("lm_eval")
filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json"
output_path = output_dir / filename
output_path.write_text(json.dumps(results["results"], indent=4))
print("Results:")
for result in results["results"].values():
print(json.dumps(result, indent=4))
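
if __name__ == "__main__":
    main()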