mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 18:11:17 +08:00)
distributed evaluate
This commit is contained in:
parent 9a3ddc3e65
commit 4385363c0f
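The diff touches the lm_eval `MLXLM` wrapper: padding now returns explicit per-sequence lengths instead of relying on a `PAD` sentinel, scoring builds a length-aware causal mask, and `loglikelihood` shards requests across ranks with `mx.distributed` before gathering the results. A minimal sketch of that sharding round trip, with hypothetical names (`run`, `score_locally`) standing in for the real methods:

import mlx.core as mx
import numpy as np

def run(requests, score_locally):
    # score_locally is a stand-in for MLXLM._loglikelihood in the diff below
    num_results = len(requests)
    group = mx.distributed.init() if mx.distributed.is_available() else None
    if group is not None:
        # strided split: rank r keeps requests r, r + size, r + 2 * size, ...
        requests = requests[group.rank() :: group.size()]

    scores = mx.array(score_locally(requests))

    if group is not None:
        per_group = int(np.ceil(num_results / group.size()))
        # pad so every rank contributes the same number of entries ...
        scores = mx.pad(scores, ((0, per_group - len(scores)),))
        # ... then gather all shards and re-interleave them into request order
        scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
        scores = scores.T.reshape(-1)

    # drop the padding entries; every rank now holds the full result
    return np.array(scores[:num_results])

The diff additionally sorts requests by length before the split (so each rank sees similar sequence lengths) and undoes the sort with np.argsort at the end.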
@@ -20,11 +20,10 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm
 
+from .models.base import create_causal_mask
 from .models.cache import make_prompt_cache
 from .utils import load, stream_generate
 
-PAD = 0
-
 
 def _len_longest_common_prefix(a, b):
     l = 0
@@ -43,31 +42,14 @@ def _rstrip_until(s, untils):
     return s[: min(f)]
 
 
-def _pad_inputs(
-    inputs,
-    maxlen,
-    genlen=0,
-    pad_left=False,
-    pad_multiple=32,
-    truncate=False,
-):
-    # pad the prompts to the left with at least genlen tokens.
-    actual_maxlen = max(len(p) for p in inputs) + genlen
-    if actual_maxlen > maxlen:
-        if not truncate:
-            raise ValueError("Inputs are too long.")
-        else:  # drop begining
-            actual_maxlen = maxlen
-            inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
-    if pad_multiple > 0:
-        maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
-        maxlen *= pad_multiple
-    assert PAD == 0
-    lr = np.array((1, 0) if pad_left else (0, 1))
-    return np.stack(
-        [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
+def _pad_inputs(inputs):
+    lengths = mx.array([len(x) for x in inputs])
+    maxlen = lengths.max()
+    padded = mx.stack(
+        [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
         axis=0,
     )
+    return padded, lengths
 
 
 @register_model("mlxlm")
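For illustration only (plain numpy standing in for the mx calls in the hunk above), the rewritten helper right-pads a ragged batch with zeros and returns the true lengths that downstream code uses in place of the old PAD mask:

import numpy as np

batch = [[5, 6, 7, 8], [9, 10]]                          # a hypothetical tokenized batch
lengths = np.array([len(x) for x in batch])              # [4, 2]
maxlen = lengths.max()
padded = np.stack([np.pad(x, (0, maxlen - len(x))) for x in batch])
# padded == [[5, 6, 7, 8],
#            [9, 10, 0, 0]]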
@@ -87,28 +69,31 @@ class MLXLM(LM):
             self.tokenizer.chat_template is not None
         )
 
-    def _score_fn(self, inputs, tokenize=True, step_size=32):
-        if tokenize:
-            inputs = self._tokenize(inputs)
-        inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
-        inputs = mx.array(inputs)
+    def _score_fn(self, inputs, step_size=64):
+        inputs, lengths = _pad_inputs(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]
 
         cache = make_prompt_cache(self._model)
 
-        mask = targets != PAD
+        # TODO: come up with a better way to get the dtype
+        dtype = self._model.model.embed_tokens(inputs).dtype
 
         scores, is_greedy = [], []
         for i in range(0, inputs.shape[1], step_size):
-            logits = self._model(inputs[:, i : i + step_size], cache=cache)
+            inp = inputs[:, i : i + step_size]
+            T = inp.shape[1]
+
+            offset = cache[0].offset
+            mask = create_causal_mask(T, offset, lengths=lengths).astype(dtype)
+
+            logits = self._model(inp, cache=cache, mask=mask)
             log_probs = nn.log_softmax(logits.astype(mx.float32))
 
             score = mx.take_along_axis(
                 log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
             )[..., 0]
-            ig = mask[:, i : i + step_size] * (
-                targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
-            )
+            ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
+            ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)
 
             mx.eval(score, ig)
             mx.metal.clear_cache()
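The scoring loop above relies on `create_causal_mask(T, offset, lengths=lengths)` from `.models.base` to hide both future positions and padding. A rough sketch of that semantics built from plain mx ops (the real helper may differ in details such as the fill value):

import mlx.core as mx

def length_aware_causal_mask(T, offset, lengths):
    # queries cover positions [offset, offset + T); keys cover [0, offset + T)
    q = mx.arange(offset, offset + T)[:, None]
    k = mx.arange(offset + T)[None, :]
    disallowed = k > q                                   # causal: no attending to the future
    # also hide everything past each sequence's true (unpadded) length
    disallowed = disallowed | (k[None] >= lengths[:, None, None])
    return disallowed * -1e9                             # additive mask: 0 keep, -1e9 drop

mask = length_aware_causal_mask(T=4, offset=0, lengths=mx.array([4, 2]))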
@@ -119,37 +104,26 @@ class MLXLM(LM):
         scores = mx.concatenate(scores, axis=1)
         is_greedy = mx.concatenate(is_greedy, axis=1)
 
-        return scores, mask.sum(axis=-1), is_greedy
+        return scores, lengths, is_greedy
 
-    def _loglikelihood(self, texts, score_spans=None, tokenize=True):
-        # sort by length to get batches with little padding.
-        sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
-        sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
-        sorted_spans = None
-        if score_spans is not None:
-            sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
-
+    def _loglikelihood(self, texts, score_spans=None):
         results = []
-        for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
-            batch = sorted_inputs[i : i + self._batch_size]
-            scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
+        for i in tqdm(range(0, len(texts), self._batch_size)):
+            batch = texts[i : i + self._batch_size]
+            scores, length, is_greedy = self._score_fn(batch)
             for j in range(len(batch)):
-                if sorted_spans is None:  # full sequence score
-                    mask = mx.arange(scores[j].shape[-1]) < length
-                    score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
-                    ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
+                if score_spans is None:  # full sequence score
+                    l = length[j].item()
+                    score = scores[j][:l].astype(mx.float32).sum()
+                    ig = is_greedy[j][:l].astype(mx.int32).sum()
                 else:  # subsequence score
-                    start, end = sorted_spans[i + j]
+                    start, end = score_spans[i + j]
                     score = scores[j][start:end].astype(mx.float32).sum()
                     ig = is_greedy[j][start:end].astype(mx.int32).sum()
                     length = end - start
 
                 results.append((score.item(), ig.item(), length))
 
-        # reorder the outputs
-        inv_sort = np.argsort(sorted_indices)
-        results = [results[inv_sort[i]] for i in range(len(results))]
-
         return results
 
     def _tokenize(self, texts):
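With lengths available, per-sequence scores are now a plain slice-and-sum rather than a multiply-by-mask reduction. A toy example, with made-up numbers, of the two scoring modes used above:

import mlx.core as mx

# toy per-token log-probs for one padded sequence of true length 3
scores_j = mx.array([-0.5, -1.0, -2.0, 0.0, 0.0])        # trailing zeros are padding
length_j = 3

# full-sequence score: sum over the first `length` tokens
full = scores_j[:length_j].astype(mx.float32).sum()       # -3.5

# span score: sum only over the completion span (start, end)
start, end = 1, 3
span = scores_j[start:end].astype(mx.float32).sum()       # -3.0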
@@ -222,13 +196,45 @@ class MLXLM(LM):
                 + "completion longer than context."
             )
 
+        num_results = len(shortened)
+
+        # sort by length to get batches with little padding.
+        sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
+        shortened = [shortened[i] for i in sorted_indices]
+        completion_spans = [completion_spans[i] for i in sorted_indices]
+
+        group = mx.distributed.init() if mx.distributed.is_available() else None
+        if group is not None:
+            # split strided so we have approximately the same lengths on each node
+            shortened = shortened[group.rank() :: group.size()]
+            completion_spans = completion_spans[group.rank() :: group.size()]
+
         # model scoring, returns num_requests x (logp, is_greedy, length).
         results = self._loglikelihood(
             shortened,
             score_spans=completion_spans,
-            tokenize=False,
         )
-        return [(r[0], r[1] == r[2]) for r in results]
+
+        scores = mx.array([r[0] for r in results])
+        is_greedy = mx.array([r[1] == r[2] for r in results])
+
+        # all gather the results across groups
+        if group is not None:
+            per_group = int(np.ceil(num_results / group.size()))
+            scores = mx.pad(scores, ((0, per_group - len(scores)),))
+            scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
+            scores = scores.T.reshape(-1)
+            is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy)),))
+            is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
+            is_greedy = is_greedy.T.reshape(-1)
+
+        scores = np.array(scores[:num_results])
+        is_greedy = np.array(is_greedy[:num_results])
+
+        results = [(score, ig) for score, ig in zip(scores, is_greedy)]
+        inv_sort = np.argsort(sorted_indices)
+        results = [results[inv_sort[i]] for i in range(len(inv_sort))]
+        return results
 
     tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
     apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
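The `.T.reshape(-1)` after `all_gather` undoes the strided split: item i went to rank i % size, so stacking the per-rank shards and transposing re-interleaves them into the original order. A small numpy check of that bookkeeping:

import numpy as np

items = np.arange(10)                                    # original, already length-sorted order
size = 3
shards = [items[r::size] for r in range(size)]           # strided split per rank
per_group = int(np.ceil(len(items) / size))
shards = [np.pad(s, (0, per_group - len(s))) for s in shards]   # equal lengths for the gather
gathered = np.stack(shards)                              # what all_gather produces, (size, per_group)
restored = gathered.T.reshape(-1)[: len(items)]          # re-interleave and trim the padding
assert (restored == items).all()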
@@ -268,7 +274,7 @@ class MLXLM(LM):
         logging.info(
             "Estimating loglikelihood rolling for %d sequences." % len(requests)
         )
-        inputs = [req.args[0] for req in requests]
+        inputs = self._tokenize([req.args[0] for req in requests])
         return [t[0] for t in self._loglikelihood(inputs)]
 
     def generate_until(self, requests) -> list[str]: