3 Commits

Author       SHA1        Message               Date
Alex Barron  f787c08585  comments              2025-01-23 12:31:59 -08:00
Alex Barron  d5f49d65b9  ordering              2025-01-23 06:37:47 -08:00
Alex Barron  4385363c0f  distributed evaluate  2025-01-23 06:37:45 -08:00


@@ -10,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional

 import lm_eval
 import mlx.core as mx
@@ -20,11 +20,10 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm

+from .models.base import create_causal_mask
 from .models.cache import make_prompt_cache
 from .utils import load, stream_generate

-PAD = 0
-

 def _len_longest_common_prefix(a, b):
     l = 0
@@ -43,31 +42,14 @@ def _rstrip_until(s, untils):
     return s[: min(f)]


-def _pad_inputs(
-    inputs,
-    maxlen,
-    genlen=0,
-    pad_left=False,
-    pad_multiple=32,
-    truncate=False,
-):
-    # pad the prompts to the left with at least genlen tokens.
-    actual_maxlen = max(len(p) for p in inputs) + genlen
-    if actual_maxlen > maxlen:
-        if not truncate:
-            raise ValueError("Inputs are too long.")
-        else:  # drop begining
-            actual_maxlen = maxlen
-            inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
-    if pad_multiple > 0:
-        maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
-        maxlen *= pad_multiple
-    assert PAD == 0
-    lr = np.array((1, 0) if pad_left else (0, 1))
-    return np.stack(
-        [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
+def _pad_inputs(inputs):
+    lengths = np.array([len(x) for x in inputs])
+    maxlen = lengths.max()
+    padded = np.stack(
+        [np.pad(x, (0, maxlen - len(x))) for x in inputs],
         axis=0,
     )
+    return mx.array(padded), mx.array(lengths)


 @register_model("mlxlm")
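As a quick sanity check on the simplified helper, here is the new `_pad_inputs` exercised on a couple of made-up token lists; the example inputs and prints are mine, not part of the diff.

import mlx.core as mx
import numpy as np

# Copied from the new version in the diff above.
def _pad_inputs(inputs):
    lengths = np.array([len(x) for x in inputs])
    maxlen = lengths.max()
    padded = np.stack(
        [np.pad(x, (0, maxlen - len(x))) for x in inputs],
        axis=0,
    )
    return mx.array(padded), mx.array(lengths)

# Hypothetical tokenized prompts of unequal length.
tokens, lengths = _pad_inputs([[1, 2, 3, 4], [5, 6]])
print(tokens)   # [[1, 2, 3, 4], [5, 6, 0, 0]] -- right-padded with zeros
print(lengths)  # [4, 2] -- true lengths, used later to mask out the padding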
@@ -83,32 +65,33 @@ class MLXLM(LM):
         self._batch_size = batch_size
         self._model, self.tokenizer = load(path_or_hf_repo)
         self._max_tokens = max_tokens or self.tokenizer.model_max_length
-        self.use_chat_template = use_chat_template or (
+        self.use_chat_template = use_chat_template and (
             self.tokenizer.chat_template is not None
         )

-    def _score_fn(self, inputs, tokenize=True, step_size=32):
-        if tokenize:
-            inputs = self._tokenize(inputs)
-        inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
-        inputs = mx.array(inputs)
+    def _score_fn(self, inputs, step_size: int = 64):
+        inputs, lengths = _pad_inputs(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]

         cache = make_prompt_cache(self._model)

-        mask = targets != PAD
-
         scores, is_greedy = [], []
         for i in range(0, inputs.shape[1], step_size):
-            logits = self._model(inputs[:, i : i + step_size], cache=cache)
+            inp = inputs[:, i : i + step_size]
+            T = inp.shape[1]
+
+            offset = cache[0].offset
+            mask = create_causal_mask(T, offset, lengths=lengths)
+            mask = mask == 0
+
+            logits = self._model(inp, cache=cache, mask=mask)

             log_probs = nn.log_softmax(logits.astype(mx.float32))
             score = mx.take_along_axis(
                 log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
             )[..., 0]
-            ig = mask[:, i : i + step_size] * (
-                targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
-            )
+            ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
+            ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)

             mx.eval(score, ig)
             mx.metal.clear_cache()
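The PAD-token mask is gone; padding is now excluded through a length-aware causal attention mask. The sketch below is my own illustration of the kind of boolean mask one scoring chunk ends up with; the real `create_causal_mask` in `mlx_lm.models.base` appears to use zero to mark attendable positions, which is why the diff converts it with `mask == 0`.

import mlx.core as mx

def sketch_length_aware_causal_mask(T, offset, lengths):
    # True where a query token may attend: causal (key index <= query index)
    # and the key lies inside the unpadded part of that batch row.
    rinds = mx.arange(offset + T)                  # key positions (incl. cached prefix)
    linds = mx.arange(offset, offset + T)          # query positions in this chunk
    causal = linds[:, None] >= rinds[None]         # (T, offset + T)
    in_seq = rinds[None] < lengths[:, None, None]  # (B, 1, offset + T)
    return causal[None] & in_seq                   # (B, T, offset + T)

# Batch of two rows: one of length 3, one of length 2 (padded to 3).
mask = sketch_length_aware_causal_mask(T=3, offset=0, lengths=mx.array([3, 2]))
print(mask.shape)  # (2, 3, 3); row 1 never attends to its padded position 2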
@@ -119,38 +102,32 @@ class MLXLM(LM):
         scores = mx.concatenate(scores, axis=1)
         is_greedy = mx.concatenate(is_greedy, axis=1)

-        return scores, mask.sum(axis=-1), is_greedy
+        return scores, lengths, is_greedy

-    def _loglikelihood(self, texts, score_spans=None, tokenize=True):
-        # sort by length to get batches with little padding.
-        sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
-        sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
-        sorted_spans = None
-        if score_spans is not None:
-            sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
-
-        results = []
-        for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
-            batch = sorted_inputs[i : i + self._batch_size]
-            scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
-            for j in range(len(batch)):
-                if sorted_spans is None:  # full sequence score
-                    mask = mx.arange(scores[j].shape[-1]) < length
-                    score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
-                    ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
-                else:  # subsequence score
-                    start, end = sorted_spans[i + j]
-                    score = scores[j][start:end].astype(mx.float32).sum()
-                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
-                    length = end - start
-
-                results.append((score.item(), ig.item(), length))
-
-        # reorder the outputs
-        inv_sort = np.argsort(sorted_indices)
-        results = [results[inv_sort[i]] for i in range(len(results))]
-
-        return results
+    def _loglikelihood(self, texts, score_spans=None):
+        all_scores = mx.zeros(len(texts))
+        all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
+        for i in tqdm(range(0, len(texts), self._batch_size)):
+            batch = texts[i : i + self._batch_size]
+            scores, lengths, is_greedy = self._score_fn(batch)
+
+            ind = np.arange(scores.shape[-1])
+            if score_spans is not None:
+                spans = score_spans[i : i + self._batch_size]
+                lengths = [end - start for start, end in spans]
+                masks = mx.array(
+                    np.array([(ind >= start) & (ind < end) for start, end in spans])
+                )
+            else:
+                masks = ind[None] < lengths[:, None]
+
+            scores = (masks * scores).sum(axis=-1)
+            is_greedy = (masks * is_greedy).sum(axis=-1)
+
+            all_scores[i : i + self._batch_size] = scores
+            all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
+
+        return all_scores, all_is_greedy

     def _tokenize(self, texts):
         return [
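The rewritten `_loglikelihood` replaces the per-example Python loop with a masked reduction over the whole batch. A toy, self-contained illustration of that span-mask reduction follows; the scores and spans below are made up.

import mlx.core as mx
import numpy as np

# Made-up per-token log-probabilities for a batch of 2 padded sequences.
scores = mx.array([[-1.0, -2.0, -0.5, -0.1, 0.0],
                   [-0.3, -0.4, -0.2, 0.0, 0.0]])
# Completion spans (start, end) into the target positions, as in the diff.
spans = [(2, 4), (1, 3)]

ind = np.arange(scores.shape[-1])
masks = mx.array(np.array([(ind >= start) & (ind < end) for start, end in spans]))
completion_logp = (masks * scores).sum(axis=-1)
print(completion_logp)  # [-0.6, -0.6]: only the span positions contribute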
@@ -222,16 +199,53 @@ class MLXLM(LM):
+ "completion longer than context." + "completion longer than context."
) )
num_results = len(shortened)
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
shortened = [shortened[i] for i in sorted_indices]
completion_spans = [completion_spans[i] for i in sorted_indices]
group = mx.distributed.init()
# split strided so we have approximately the same lengths on each node
shortened = shortened[group.rank() :: group.size()]
completion_spans = completion_spans[group.rank() :: group.size()]
# model scoring, returns num_requests x (logp, is_greedy, length). # model scoring, returns num_requests x (logp, is_greedy, length).
results = self._loglikelihood( scores, is_greedy = self._loglikelihood(
shortened, shortened,
score_spans=completion_spans, score_spans=completion_spans,
tokenize=False,
) )
return [(r[0], r[1] == r[2]) for r in results]
# all gather the results across groups
if group.size() > 1:
per_group = int(np.ceil(num_results / group.size()))
scores = mx.pad(scores, ((0, per_group - len(scores)),))
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
scores = scores.T.reshape(-1)
is_greedy = is_greedy.T.reshape(-1)
scores = np.array(scores[:num_results])
is_greedy = np.array(is_greedy[:num_results])
results = [(score, ig) for score, ig in zip(scores, is_greedy)]
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
return results
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
if len(chat_history) == 0:
return ""
return lm_eval.models.huggingface.HFLM.apply_chat_template(
chat_history, add_generation_prompt
)
def loglikelihood_rolling(self, requests) -> list[float]: def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
@@ -268,8 +282,9 @@ class MLXLM(LM):
         logging.info(
             "Estimating loglikelihood rolling for %d sequences." % len(requests)
         )
-        inputs = [req.args[0] for req in requests]
-
-        return [t[0] for t in self._loglikelihood(inputs)]
+        inputs = self._tokenize([req.args[0] for req in requests])
+        scores, _ = self._loglikelihood(inputs)
+        return scores.tolist()

     def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
@@ -332,7 +347,7 @@ def main():
     )
     parser.add_argument(
         "--limit",
-        default=1.0,
+        default=None,
         help="Limit the number of examples per task.",
         type=float,
     )
@@ -346,11 +361,8 @@ def main():
     )
     parser.add_argument(
         "--apply-chat-template",
-        action=argparse.BooleanOptionalAction,
-        help="Specifies whether to apply a chat template to the prompt. If "
-        "the model has a chat template, this defaults to `True`, "
-        "otherwise `False`.",
-        default=None,
+        action="store_true",
+        help="Specifies whether to apply a chat template to the prompt.",
     )

     args = parser.parse_args()