diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index ca5e83bb..8fa00dd2 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -20,11 +20,10 @@ from lm_eval.api.model import LM from lm_eval.api.registry import register_model from tqdm import tqdm +from .models.base import create_causal_mask from .models.cache import make_prompt_cache from .utils import load, stream_generate -PAD = 0 - def _len_longest_common_prefix(a, b): l = 0 @@ -43,31 +42,14 @@ def _rstrip_until(s, untils): return s[: min(f)] -def _pad_inputs( - inputs, - maxlen, - genlen=0, - pad_left=False, - pad_multiple=32, - truncate=False, -): - # pad the prompts to the left with at least genlen tokens. - actual_maxlen = max(len(p) for p in inputs) + genlen - if actual_maxlen > maxlen: - if not truncate: - raise ValueError("Inputs are too long.") - else: # drop begining - actual_maxlen = maxlen - inputs = [p[max(0, len(p) - maxlen) :] for p in inputs] - if pad_multiple > 0: - maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple - maxlen *= pad_multiple - assert PAD == 0 - lr = np.array((1, 0) if pad_left else (0, 1)) - return np.stack( - [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs], +def _pad_inputs(inputs): + lengths = mx.array([len(x) for x in inputs]) + maxlen = lengths.max() + padded = mx.stack( + [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs], axis=0, ) + return padded, lengths @register_model("mlxlm") @@ -87,28 +69,31 @@ class MLXLM(LM): self.tokenizer.chat_template is not None ) - def _score_fn(self, inputs, tokenize=True, step_size=32): - if tokenize: - inputs = self._tokenize(inputs) - inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) - inputs = mx.array(inputs) + def _score_fn(self, inputs, step_size=64): + inputs, lengths = _pad_inputs(inputs) inputs, targets = inputs[..., :-1], inputs[..., 1:] cache = make_prompt_cache(self._model) - mask = targets != PAD + # TODO: come up with a better way to get the dtype + dtype = self._model.model.embed_tokens(inputs).dtype scores, is_greedy = [], [] for i in range(0, inputs.shape[1], step_size): - logits = self._model(inputs[:, i : i + step_size], cache=cache) + inp = inputs[:, i : i + step_size] + T = inp.shape[1] + offset = cache[0].offset + mask = create_causal_mask(T, offset, lengths=lengths).astype(dtype) + + logits = self._model(inp, cache=cache, mask=mask) log_probs = nn.log_softmax(logits.astype(mx.float32)) + score = mx.take_along_axis( log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1 )[..., 0] - ig = mask[:, i : i + step_size] * ( - targets[:, i : i + step_size] == mx.argmax(logits, axis=-1) - ) + ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1) + ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False) mx.eval(score, ig) mx.metal.clear_cache() @@ -119,37 +104,26 @@ class MLXLM(LM): scores = mx.concatenate(scores, axis=1) is_greedy = mx.concatenate(is_greedy, axis=1) - return scores, mask.sum(axis=-1), is_greedy - - def _loglikelihood(self, texts, score_spans=None, tokenize=True): - # sort by length to get batches with little padding. - sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i])) - sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))] - sorted_spans = None - if score_spans is not None: - sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))] + return scores, lengths, is_greedy + def _loglikelihood(self, texts, score_spans=None): results = [] - for i in tqdm(range(0, len(sorted_inputs), self._batch_size)): - batch = sorted_inputs[i : i + self._batch_size] - scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize) + for i in tqdm(range(0, len(texts), self._batch_size)): + batch = texts[i : i + self._batch_size] + scores, length, is_greedy = self._score_fn(batch) for j in range(len(batch)): - if sorted_spans is None: # full sequence score - mask = mx.arange(scores[j].shape[-1]) < length - score = (scores[j].astype(mx.float32) * mask).sum(axis=-1) - ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1) + if score_spans is None: # full sequence score + l = length[j].item() + score = scores[j][:l].astype(mx.float32).sum() + ig = is_greedy[j][:l].astype(mx.int32).sum() else: # subsequence score - start, end = sorted_spans[i + j] + start, end = score_spans[i + j] score = scores[j][start:end].astype(mx.float32).sum() ig = is_greedy[j][start:end].astype(mx.int32).sum() length = end - start results.append((score.item(), ig.item(), length)) - # reorder the outputs - inv_sort = np.argsort(sorted_indices) - results = [results[inv_sort[i]] for i in range(len(results))] - return results def _tokenize(self, texts): @@ -222,13 +196,45 @@ class MLXLM(LM): + "completion longer than context." ) + num_results = len(shortened) + + # sort by length to get batches with little padding. + sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i])) + shortened = [shortened[i] for i in sorted_indices] + completion_spans = [completion_spans[i] for i in sorted_indices] + + group = mx.distributed.init() if mx.distributed.is_available() else None + if group is not None: + # split strided so we have approximately the same lengths on each node + shortened = shortened[group.rank() :: group.size()] + completion_spans = completion_spans[group.rank() :: group.size()] + # model scoring, returns num_requests x (logp, is_greedy, length). results = self._loglikelihood( shortened, score_spans=completion_spans, - tokenize=False, ) - return [(r[0], r[1] == r[2]) for r in results] + + scores = mx.array([r[0] for r in results]) + is_greedy = mx.array([r[1] == r[2] for r in results]) + + # all gather the results across groups + if group is not None: + per_group = int(np.ceil(num_results / group.size())) + scores = mx.pad(scores, ((0, per_group - len(scores)),)) + scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu) + scores = scores.T.reshape(-1) + is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu) + is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy)))) + is_greedy = is_greedy.T.reshape(-1) + + scores = np.array(scores[:num_results]) + is_greedy = np.array(is_greedy[:num_results]) + + results = [(score, ig) for score, ig in zip(scores, is_greedy)] + inv_sort = np.argsort(sorted_indices) + results = [results[inv_sort[i]] for i in range(len(inv_sort))] + return results tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template @@ -268,7 +274,7 @@ class MLXLM(LM): logging.info( "Estimating loglikelihood rolling for %d sequences." % len(requests) ) - inputs = [req.args[0] for req in requests] + inputs = self._tokenize([req.args[0] for req in requests]) return [t[0] for t in self._loglikelihood(inputs)] def generate_until(self, requests) -> list[str]: