Add HF overrides for methods needed by added options

This commit is contained in:
Chime Ogbuji 2024-12-23 12:02:37 -05:00 committed by Awni Hannun
parent d352074e73
commit e1072b5300

View File

@ -83,6 +83,9 @@ class MLXLM(LM):
self._model, self._tokenizer = load(path_or_hf_repo) self._model, self._tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self._tokenizer.model_max_length self._max_tokens = max_tokens or self._tokenizer.model_max_length
# Needed by HF implementation methods (tokenizer_name, apply_chat_template, and tok_encode)
self.tokenizer = self._tokenizer
def _score_fn(self, inputs, tokenize=True, step_size=32): def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize: if tokenize:
inputs = self._tokenizer.encode(inputs) inputs = self._tokenizer.encode(inputs)
@ -221,6 +224,10 @@ class MLXLM(LM):
) )
return [(r[0], r[1] == r[2]) for r in results] return [(r[0], r[1] == r[2]) for r in results]
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
tok_encode = lm_eval.models.huggingface.HFLM.tok_encode
def loglikelihood_rolling(self, requests) -> list[float]: def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model. - We will use the full max context length of the model.