Alex Barron 2025-01-23 06:36:31 -08:00
parent d5f49d65b9
commit f787c08585


@@ -10,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional
 
 import lm_eval
 import mlx.core as mx
@@ -43,13 +43,13 @@ def _rstrip_until(s, untils):
 
 
 def _pad_inputs(inputs):
-    lengths = mx.array([len(x) for x in inputs])
+    lengths = np.array([len(x) for x in inputs])
     maxlen = lengths.max()
-    padded = mx.stack(
-        [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
+    padded = np.stack(
+        [np.pad(x, (0, maxlen - len(x))) for x in inputs],
         axis=0,
     )
-    return padded, lengths
+    return mx.array(padded), mx.array(lengths)
 
 
 @register_model("mlxlm")
@@ -65,26 +65,24 @@ class MLXLM(LM):
         self._batch_size = batch_size
         self._model, self.tokenizer = load(path_or_hf_repo)
         self._max_tokens = max_tokens or self.tokenizer.model_max_length
-        self.use_chat_template = use_chat_template or (
+        self.use_chat_template = use_chat_template and (
             self.tokenizer.chat_template is not None
         )
 
-    def _score_fn(self, inputs, step_size=64):
+    def _score_fn(self, inputs, step_size: int = 64):
         inputs, lengths = _pad_inputs(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]
 
         cache = make_prompt_cache(self._model)
 
-        # TODO: come up with a better way to get the dtype
-        dtype = self._model.model.embed_tokens(inputs).dtype
-
         scores, is_greedy = [], []
         for i in range(0, inputs.shape[1], step_size):
             inp = inputs[:, i : i + step_size]
             T = inp.shape[1]
 
             offset = cache[0].offset
-            mask = create_causal_mask(T, offset, lengths=lengths).astype(dtype)
+            mask = create_causal_mask(T, offset, lengths=lengths)
+            mask = mask == 0
 
             logits = self._model(inp, cache=cache, mask=mask)
             log_probs = nn.log_softmax(logits.astype(mx.float32))
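
Note: the removed dtype plumbing was only needed because the additive causal mask had to match the model's compute dtype. Comparing the additive mask against zero instead yields a boolean mask, which is dtype independent. A minimal sketch, assuming `create_causal_mask` from `mlx_lm.models.base` as used in this file:

    import mlx.core as mx
    from mlx_lm.models.base import create_causal_mask

    lengths = mx.array([4, 2])
    additive = create_causal_mask(4, 0, lengths=lengths)
    # Allowed positions are 0 and disallowed ones a large negative
    # value, so `== 0` converts the additive mask into boolean form.
    boolean = additive == 0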
@@ -107,24 +105,29 @@ class MLXLM(LM):
         return scores, lengths, is_greedy
 
     def _loglikelihood(self, texts, score_spans=None):
-        results = []
+        all_scores = mx.zeros(len(texts))
+        all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
         for i in tqdm(range(0, len(texts), self._batch_size)):
             batch = texts[i : i + self._batch_size]
-            scores, length, is_greedy = self._score_fn(batch)
-            for j in range(len(batch)):
-                if score_spans is None:  # full sequence score
-                    l = length[j].item()
-                    score = scores[j][:l].astype(mx.float32).sum()
-                    ig = is_greedy[j][:l].astype(mx.int32).sum()
-                else:  # subsequence score
-                    start, end = score_spans[i + j]
-                    score = scores[j][start:end].astype(mx.float32).sum()
-                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
-                    length = end - start
-                results.append((score.item(), ig.item(), length))
-        return results
+            scores, lengths, is_greedy = self._score_fn(batch)
+            ind = np.arange(scores.shape[-1])
+            if score_spans is not None:
+                spans = score_spans[i : i + self._batch_size]
+                lengths = [end - start for start, end in spans]
+                masks = mx.array(
+                    np.array([(ind >= start) & (ind < end) for start, end in spans])
+                )
+            else:
+                masks = ind[None] < lengths[:, None]
+            scores = (masks * scores).sum(axis=-1)
+            is_greedy = (masks * is_greedy).sum(axis=-1)
+            all_scores[i : i + self._batch_size] = scores
+            all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
+        return all_scores, all_is_greedy
 
     def _tokenize(self, texts):
         return [
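
Note: the per-example Python loop is replaced with one masked reduction per batch: the mask is `position < length` for whole-sequence scoring, or membership in the half-open `(start, end)` span when only completions are scored, and `is_greedy == lengths` then checks that every scored token was the argmax. A small sketch with made-up scores and spans:

    import numpy as np
    import mlx.core as mx

    scores = mx.arange(10, dtype=mx.float32).reshape(2, 5)
    ind = np.arange(scores.shape[-1])

    # Whole-sequence case: keep positions before each row's length.
    lengths = mx.array([5, 3])
    masks = mx.array(ind)[None] < lengths[:, None]

    # Span case: keep positions inside each (start, end) span.
    spans = [(1, 4), (0, 2)]
    span_masks = mx.array(np.array([(ind >= s) & (ind < e) for s, e in spans]))

    print((masks * scores).sum(axis=-1))       # per-sequence sums
    print((span_masks * scores).sum(axis=-1))  # per-span sums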
@@ -203,23 +206,20 @@ class MLXLM(LM):
         shortened = [shortened[i] for i in sorted_indices]
         completion_spans = [completion_spans[i] for i in sorted_indices]
 
-        group = mx.distributed.init() if mx.distributed.is_available() else None
-        if group is not None:
-            # split strided so we have approximately the same lengths on each node
-            shortened = shortened[group.rank() :: group.size()]
-            completion_spans = completion_spans[group.rank() :: group.size()]
+        group = mx.distributed.init()
+        # split strided so we have approximately the same lengths on each node
+        shortened = shortened[group.rank() :: group.size()]
+        completion_spans = completion_spans[group.rank() :: group.size()]
 
         # model scoring, returns num_requests x (logp, is_greedy, length).
-        results = self._loglikelihood(
+        scores, is_greedy = self._loglikelihood(
             shortened,
             score_spans=completion_spans,
         )
-        scores = mx.array([r[0] for r in results])
-        is_greedy = mx.array([r[1] == r[2] for r in results])
 
         # all gather the results across groups
-        if group is not None:
+        if group.size() > 1:
             per_group = int(np.ceil(num_results / group.size()))
             scores = mx.pad(scores, ((0, per_group - len(scores)),))
             is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
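
Note: `mx.distributed.init()` can be called unconditionally; when no distributed backend is launched it returns a singleton group of size 1, so the strided split keeps every item and the gather branch is skipped via `group.size() > 1`. Sketch of the strided split over a made-up work list:

    import mlx.core as mx

    group = mx.distributed.init()   # size 1 outside a distributed launch
    rank, size = group.rank(), group.size()

    work = list(range(10))
    mine = work[rank::size]         # every item when size == 1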
@@ -237,7 +237,15 @@ class MLXLM(LM):
         return results
 
     tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
-    apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
+
+    def apply_chat_template(
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
+    ) -> str:
+        if len(chat_history) == 0:
+            return ""
+        return lm_eval.models.huggingface.HFLM.apply_chat_template(
+            self, chat_history, add_generation_prompt
+        )
 
     def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
@@ -275,7 +283,8 @@ class MLXLM(LM):
             "Estimating loglikelihood rolling for %d sequences." % len(requests)
         )
         inputs = self._tokenize([req.args[0] for req in requests])
-        return [t[0] for t in self._loglikelihood(inputs)]
+        scores, _ = self._loglikelihood(inputs)
+        return scores.tolist()
 
     def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
@@ -338,7 +347,7 @@ def main():
     )
     parser.add_argument(
         "--limit",
-        default=1.0,
+        default=None,
        help="Limit the number of examples per task.",
         type=float,
     )
@@ -352,11 +361,8 @@ def main():
     )
     parser.add_argument(
         "--apply-chat-template",
-        action=argparse.BooleanOptionalAction,
-        help="Specifies whether to apply a chat template to the prompt. If "
-        "the model has a chat template, this defaults to `True`, "
-        "otherwise `False`.",
-        default=None,
+        action="store_true",
+        help="Specifies whether to apply a chat template to the prompt.",
     )
     args = parser.parse_args()
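
Note: with `action="store_true"` the flag defaults to False and the `--no-apply-chat-template` form goes away; together with the `or` to `and` change in the constructor, a chat template is applied only when the flag is passed and the tokenizer actually has one. A quick sketch of the new flag semantics:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--apply-chat-template",
        action="store_true",
        help="Specifies whether to apply a chat template to the prompt.",
    )
    print(parser.parse_args([]).apply_chat_template)    # False
    print(parser.parse_args(["--apply-chat-template"]).apply_chat_template)  # True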