mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
This commit is contained in:
parent d5f49d65b9
commit f787c08585
@@ -10,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional
 
 import lm_eval
 import mlx.core as mx
@@ -43,13 +43,13 @@ def _rstrip_until(s, untils):
 
 
 def _pad_inputs(inputs):
-    lengths = mx.array([len(x) for x in inputs])
+    lengths = np.array([len(x) for x in inputs])
     maxlen = lengths.max()
-    padded = mx.stack(
-        [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
+    padded = np.stack(
+        [np.pad(x, (0, maxlen - len(x))) for x in inputs],
         axis=0,
     )
-    return padded, lengths
+    return mx.array(padded), mx.array(lengths)
 
 
 @register_model("mlxlm")
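As a standalone illustration, the padding helper now builds the batch on the host with NumPy and only converts to MLX arrays at the end; a minimal sketch with hypothetical token-id lists:

import numpy as np
import mlx.core as mx


def _pad_inputs(inputs):
    # Right-pad variable-length token lists with zeros using NumPy,
    # then hand MLX one (batch, maxlen) array plus the true lengths.
    lengths = np.array([len(x) for x in inputs])
    maxlen = lengths.max()
    padded = np.stack(
        [np.pad(x, (0, maxlen - len(x))) for x in inputs], axis=0
    )
    return mx.array(padded), mx.array(lengths)


# Hypothetical token-id lists of uneven length.
padded, lengths = _pad_inputs([[1, 2, 3], [4, 5], [6]])
print(padded.shape)  # (3, 3)
print(lengths)       # [3, 2, 1]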
@@ -65,26 +65,24 @@ class MLXLM(LM):
         self._batch_size = batch_size
         self._model, self.tokenizer = load(path_or_hf_repo)
         self._max_tokens = max_tokens or self.tokenizer.model_max_length
-        self.use_chat_template = use_chat_template or (
+        self.use_chat_template = use_chat_template and (
             self.tokenizer.chat_template is not None
         )
 
-    def _score_fn(self, inputs, step_size=64):
+    def _score_fn(self, inputs, step_size: int = 64):
         inputs, lengths = _pad_inputs(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]
 
         cache = make_prompt_cache(self._model)
 
-        # TODO: come up with a better way to get the dtype
-        dtype = self._model.model.embed_tokens(inputs).dtype
-
         scores, is_greedy = [], []
         for i in range(0, inputs.shape[1], step_size):
             inp = inputs[:, i : i + step_size]
             T = inp.shape[1]
 
             offset = cache[0].offset
-            mask = create_causal_mask(T, offset, lengths=lengths).astype(dtype)
+            mask = create_causal_mask(T, offset, lengths=lengths)
+            mask = mask == 0
 
             logits = self._model(inp, cache=cache, mask=mask)
             log_probs = nn.log_softmax(logits.astype(mx.float32))
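For context, a minimal sketch of the `mask == 0` conversion above, assuming the causal mask is additive (0 where a position may attend, a large negative value elsewhere); a hypothetical stand-in mask is used here, and comparing with 0 yields a boolean keep/drop mask:

import mlx.core as mx


def additive_causal_mask(T: int, offset: int = 0):
    # Hypothetical stand-in: 0 where query i may attend to key j, -1e9 otherwise.
    q = mx.arange(offset, offset + T)[:, None]
    k = mx.arange(offset + T)[None]
    return (q < k) * -1e9


additive = additive_causal_mask(3)
boolean = additive == 0  # True = keep, False = masked out
print(boolean)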
@@ -107,24 +105,29 @@ class MLXLM(LM):
         return scores, lengths, is_greedy
 
     def _loglikelihood(self, texts, score_spans=None):
-        results = []
+        all_scores = mx.zeros(len(texts))
+        all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
         for i in tqdm(range(0, len(texts), self._batch_size)):
             batch = texts[i : i + self._batch_size]
-            scores, length, is_greedy = self._score_fn(batch)
-            for j in range(len(batch)):
-                if score_spans is None:  # full sequence score
-                    l = length[j].item()
-                    score = scores[j][:l].astype(mx.float32).sum()
-                    ig = is_greedy[j][:l].astype(mx.int32).sum()
-                else:  # subsequence score
-                    start, end = score_spans[i + j]
-                    score = scores[j][start:end].astype(mx.float32).sum()
-                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
-                    length = end - start
+            scores, lengths, is_greedy = self._score_fn(batch)
 
-                results.append((score.item(), ig.item(), length))
+            ind = np.arange(scores.shape[-1])
+            if score_spans is not None:
+                spans = score_spans[i : i + self._batch_size]
+                lengths = [end - start for start, end in spans]
+                masks = mx.array(
+                    np.array([(ind >= start) & (ind < end) for start, end in spans])
+                )
+            else:
+                masks = ind[None] < lengths[:, None]
 
-        return results
+            scores = (masks * scores).sum(axis=-1)
+            is_greedy = (masks * is_greedy).sum(axis=-1)
+
+            all_scores[i : i + self._batch_size] = scores
+            all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
+
+        return all_scores, all_is_greedy
 
     def _tokenize(self, texts):
         return [
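The new version trades the per-example Python loop for a masked reduction over the whole batch; a small NumPy sketch of the same span-masking idea, with hypothetical per-token scores and spans:

import numpy as np

# Hypothetical per-token log-probabilities for two sequences of length 5.
scores = np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                   [1.0, 1.0, 1.0, 1.0, 1.0]])
spans = [(1, 4), (0, 2)]  # score only tokens in [start, end) of each row

ind = np.arange(scores.shape[-1])
masks = np.array([(ind >= start) & (ind < end) for start, end in spans])

# One vectorized masked sum replaces slicing each row in a loop.
print((masks * scores).sum(axis=-1))  # approx. [0.9, 2.0]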
@@ -203,23 +206,20 @@ class MLXLM(LM):
         shortened = [shortened[i] for i in sorted_indices]
         completion_spans = [completion_spans[i] for i in sorted_indices]
 
-        group = mx.distributed.init() if mx.distributed.is_available() else None
-        if group is not None:
+        group = mx.distributed.init()
         # split strided so we have approximately the same lengths on each node
         shortened = shortened[group.rank() :: group.size()]
         completion_spans = completion_spans[group.rank() :: group.size()]
 
         # model scoring, returns num_requests x (logp, is_greedy, length).
-        results = self._loglikelihood(
+        scores, is_greedy = self._loglikelihood(
             shortened,
             score_spans=completion_spans,
         )
 
-        scores = mx.array([r[0] for r in results])
-        is_greedy = mx.array([r[1] == r[2] for r in results])
 
         # all gather the results across groups
-        if group is not None:
+        if group.size() > 1:
             per_group = int(np.ceil(num_results / group.size()))
             scores = mx.pad(scores, ((0, per_group - len(scores)),))
             is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
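For reference, the strided split above assigns request r to rank r mod world-size, so each node receives a similar mix of short and long requests; a minimal sketch with a hypothetical world size of 2 (in the script, rank and size come from mx.distributed.init()):

requests = ["req0", "req1", "req2", "req3", "req4", "req5", "req6"]
size = 2  # hypothetical world size; in the script this is group.size()
for rank in range(size):
    print(rank, requests[rank::size])
# 0 ['req0', 'req2', 'req4', 'req6']
# 1 ['req1', 'req3', 'req5']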
@@ -237,7 +237,15 @@ class MLXLM(LM):
         return results
 
     tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
-    apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
+
+    def apply_chat_template(
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
+    ) -> str:
+        if len(chat_history) == 0:
+            return ""
+        return lm_eval.models.huggingface.HFLM.apply_chat_template(
+            chat_history, add_generation_prompt
+        )
 
     def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
@@ -275,7 +283,8 @@ class MLXLM(LM):
             "Estimating loglikelihood rolling for %d sequences." % len(requests)
         )
         inputs = self._tokenize([req.args[0] for req in requests])
-        return [t[0] for t in self._loglikelihood(inputs)]
+        scores, _ = self._loglikelihood(inputs)
+        return scores.tolist()
 
     def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
@@ -338,7 +347,7 @@ def main():
     )
     parser.add_argument(
        "--limit",
-        default=1.0,
+        default=None,
        help="Limit the number of examples per task.",
        type=float,
    )
@@ -352,11 +361,8 @@ def main():
    )
    parser.add_argument(
        "--apply-chat-template",
-        action=argparse.BooleanOptionalAction,
-        help="Specifies whether to apply a chat template to the prompt. If "
-        "the model has a chat template, this defaults to `True`, "
-        "otherwise `False`.",
-        default=None,
+        action="store_true",
+        help="Specifies whether to apply a chat template to the prompt.",
    )
    args = parser.parse_args()
 
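A minimal sketch of the resulting CLI behaviour with Python's standard argparse: store_true defaults the flag to False and flips it to True when the flag is passed, and no --no-apply-chat-template counterpart is generated (that was only provided by BooleanOptionalAction):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--apply-chat-template", action="store_true")

print(parser.parse_args([]).apply_chat_template)                         # False
print(parser.parse_args(["--apply-chat-template"]).apply_chat_template)  # True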