3 Commits

Author SHA1 Message Date
Alex Barron
f787c08585 comments 2025-01-23 12:31:59 -08:00
Alex Barron
d5f49d65b9 ordering 2025-01-23 06:37:47 -08:00
Alex Barron
4385363c0f distributed evaluate 2025-01-23 06:37:45 -08:00

View File

@@ -10,7 +10,7 @@ import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional, Union
from typing import Optional
import lm_eval
import mlx.core as mx
@@ -20,11 +20,10 @@ from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
from .models.base import create_causal_mask
from .models.cache import make_prompt_cache
from .utils import load, stream_generate
PAD = 0
def _len_longest_common_prefix(a, b):
l = 0
@@ -43,31 +42,14 @@ def _rstrip_until(s, untils):
return s[: min(f)]
def _pad_inputs(
inputs,
maxlen,
genlen=0,
pad_left=False,
pad_multiple=32,
truncate=False,
):
# pad the prompts to the left with at least genlen tokens.
actual_maxlen = max(len(p) for p in inputs) + genlen
if actual_maxlen > maxlen:
if not truncate:
raise ValueError("Inputs are too long.")
else: # drop begining
actual_maxlen = maxlen
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
if pad_multiple > 0:
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
maxlen *= pad_multiple
assert PAD == 0
lr = np.array((1, 0) if pad_left else (0, 1))
return np.stack(
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
def _pad_inputs(inputs):
lengths = np.array([len(x) for x in inputs])
maxlen = lengths.max()
padded = np.stack(
[np.pad(x, (0, maxlen - len(x))) for x in inputs],
axis=0,
)
return mx.array(padded), mx.array(lengths)
@register_model("mlxlm")
@@ -83,32 +65,33 @@ class MLXLM(LM):
self._batch_size = batch_size
self._model, self.tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self.tokenizer.model_max_length
self.use_chat_template = use_chat_template or (
self.use_chat_template = use_chat_template and (
self.tokenizer.chat_template is not None
)
def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize:
inputs = self._tokenize(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs)
def _score_fn(self, inputs, step_size: int = 64):
inputs, lengths = _pad_inputs(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:]
cache = make_prompt_cache(self._model)
mask = targets != PAD
scores, is_greedy = [], []
for i in range(0, inputs.shape[1], step_size):
logits = self._model(inputs[:, i : i + step_size], cache=cache)
inp = inputs[:, i : i + step_size]
T = inp.shape[1]
offset = cache[0].offset
mask = create_causal_mask(T, offset, lengths=lengths)
mask = mask == 0
logits = self._model(inp, cache=cache, mask=mask)
log_probs = nn.log_softmax(logits.astype(mx.float32))
score = mx.take_along_axis(
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
)[..., 0]
ig = mask[:, i : i + step_size] * (
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
)
ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)
mx.eval(score, ig)
mx.metal.clear_cache()
@@ -119,38 +102,32 @@ class MLXLM(LM):
scores = mx.concatenate(scores, axis=1)
is_greedy = mx.concatenate(is_greedy, axis=1)
return scores, mask.sum(axis=-1), is_greedy
return scores, lengths, is_greedy
def _loglikelihood(self, texts, score_spans=None, tokenize=True):
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
sorted_spans = None
if score_spans is not None:
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
def _loglikelihood(self, texts, score_spans=None):
all_scores = mx.zeros(len(texts))
all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
for i in tqdm(range(0, len(texts), self._batch_size)):
batch = texts[i : i + self._batch_size]
scores, lengths, is_greedy = self._score_fn(batch)
results = []
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
batch = sorted_inputs[i : i + self._batch_size]
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
for j in range(len(batch)):
if sorted_spans is None: # full sequence score
mask = mx.arange(scores[j].shape[-1]) < length
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
else: # subsequence score
start, end = sorted_spans[i + j]
score = scores[j][start:end].astype(mx.float32).sum()
ig = is_greedy[j][start:end].astype(mx.int32).sum()
length = end - start
ind = np.arange(scores.shape[-1])
if score_spans is not None:
spans = score_spans[i : i + self._batch_size]
lengths = [end - start for start, end in spans]
masks = mx.array(
np.array([(ind >= start) & (ind < end) for start, end in spans])
)
else:
masks = ind[None] < lengths[:, None]
results.append((score.item(), ig.item(), length))
scores = (masks * scores).sum(axis=-1)
is_greedy = (masks * is_greedy).sum(axis=-1)
# reorder the outputs
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(results))]
all_scores[i : i + self._batch_size] = scores
all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
return results
return all_scores, all_is_greedy
def _tokenize(self, texts):
return [
@@ -222,16 +199,53 @@ class MLXLM(LM):
+ "completion longer than context."
)
num_results = len(shortened)
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
shortened = [shortened[i] for i in sorted_indices]
completion_spans = [completion_spans[i] for i in sorted_indices]
group = mx.distributed.init()
# split strided so we have approximately the same lengths on each node
shortened = shortened[group.rank() :: group.size()]
completion_spans = completion_spans[group.rank() :: group.size()]
# model scoring, returns num_requests x (logp, is_greedy, length).
results = self._loglikelihood(
scores, is_greedy = self._loglikelihood(
shortened,
score_spans=completion_spans,
tokenize=False,
)
return [(r[0], r[1] == r[2]) for r in results]
# all gather the results across groups
if group.size() > 1:
per_group = int(np.ceil(num_results / group.size()))
scores = mx.pad(scores, ((0, per_group - len(scores)),))
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
scores = scores.T.reshape(-1)
is_greedy = is_greedy.T.reshape(-1)
scores = np.array(scores[:num_results])
is_greedy = np.array(is_greedy[:num_results])
results = [(score, ig) for score, ig in zip(scores, is_greedy)]
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
return results
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
if len(chat_history) == 0:
return ""
return lm_eval.models.huggingface.HFLM.apply_chat_template(
chat_history, add_generation_prompt
)
def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
@@ -268,8 +282,9 @@ class MLXLM(LM):
logging.info(
"Estimating loglikelihood rolling for %d sequences." % len(requests)
)
inputs = [req.args[0] for req in requests]
return [t[0] for t in self._loglikelihood(inputs)]
inputs = self._tokenize([req.args[0] for req in requests])
scores, _ = self._loglikelihood(inputs)
return scores.tolist()
def generate_until(self, requests) -> list[str]:
"""Generate greedily until a stopping sequence
@@ -332,7 +347,7 @@ def main():
)
parser.add_argument(
"--limit",
default=1.0,
default=None,
help="Limit the number of examples per task.",
type=float,
)
@@ -346,11 +361,8 @@ def main():
)
parser.add_argument(
"--apply-chat-template",
action=argparse.BooleanOptionalAction,
help="Specifies whether to apply a chat template to the prompt. If "
"the model has a chat template, this defaults to `True`, "
"otherwise `False`.",
default=None,
action="store_true",
help="Specifies whether to apply a chat template to the prompt.",
)
args = parser.parse_args()