Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00.
Add support for fewshot and apply chat template lm_eval functionality (#1180)
* Add support for multiturn fewshot examples and chat templates

  Added two new arguments to the evaluation script, `--fewshot-as-multiturn` and `--apply-chat-template`, which correspond to the lm_eval options of similar names and are very often used to ensure apples-to-apples comparisons of lm_eval results.

* Add HF overrides for methods needed by the added options

* Don't add a duplicate BOS token

---------

Co-authored-by: Awni Hannun <awni@apple.com>
parent 25ec2d8c44
commit f2619f507c
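Before the diff itself, a sketch of how the new options flow end to end once this commit lands: the script builds the MLXLM wrapper (defined in the diff below) and forwards both flags to lm_eval's harness. This is a minimal sketch, not the verbatim script; the model path and task name are illustrative placeholders.

import lm_eval

lm = MLXLM(  # the wrapper class defined in the evaluation-script diff below
    "some-model-path-or-hf-repo",  # placeholder for path_or_hf_repo
    batch_size=16,
    max_tokens=None,
    use_chat_template=None,  # None -> enabled iff the tokenizer ships a chat template
)
results = lm_eval.simple_evaluate(
    model=lm,
    tasks=["arc_easy"],  # placeholder task
    num_fewshot=5,
    fewshot_as_multiturn=True,  # few-shot examples rendered as separate chat turns
    apply_chat_template=lm.use_chat_template,
)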
@@ -32,7 +32,7 @@ jobs:
             pip install --upgrade pip
             pip install unittest-xml-reporting
             cd llms/
-            pip install -e ".[testing]"
+            pip install -e ".[test]"
      - run:
          name: Run Python tests
          command: |
@@ -77,15 +77,19 @@ class MLXLM(LM):
         path_or_hf_repo: str,
         batch_size: int = 16,
         max_tokens: Optional[int] = None,
+        use_chat_template: Optional[bool] = None,
     ) -> None:
         super().__init__()
         self._batch_size = batch_size
-        self._model, self._tokenizer = load(path_or_hf_repo)
-        self._max_tokens = max_tokens or self._tokenizer.model_max_length
+        self._model, self.tokenizer = load(path_or_hf_repo)
+        self._max_tokens = max_tokens or self.tokenizer.model_max_length
+        self.use_chat_template = use_chat_template or (
+            self.tokenizer.chat_template is not None
+        )

     def _score_fn(self, inputs, tokenize=True, step_size=32):
         if tokenize:
-            inputs = self._tokenizer.encode(inputs)
+            inputs = self._tokenize(inputs)
         inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
         inputs = mx.array(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]
@@ -149,7 +153,12 @@ class MLXLM(LM):
         return results

     def _tokenize(self, texts):
-        return [tuple(self._tokenizer.encode(t)) for t in texts]
+        return [
+            tuple(
+                self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
+            )
+            for t in texts
+        ]

     def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
@@ -221,6 +230,9 @@ class MLXLM(LM):
         )
         return [(r[0], r[1] == r[2]) for r in results]

+    tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
+    apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
+
     def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
@@ -283,21 +295,14 @@ class MLXLM(LM):
         completions = []

         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            if self._tokenizer.chat_template is not None:
-                messages = [{"role": "user", "content": context}]
-                context = self._tokenizer.apply_chat_template(
-                    messages, add_generation_prompt=True
-                )
-            else:
-                context = self._tokenizer.encode(context)
-
+            context = self._tokenize(context)
             max_tokens = min(
                 self._max_tokens,
-                self._tokenizer.model_max_length - len(context),
+                self.tokenizer.model_max_length - len(context),
             )
             text = ""
             for response in stream_generate(
-                self._model, self._tokenizer, prompt=context, max_tokens=max_tokens
+                self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
             ):
                 text += response.text
                 if any(u in text for u in until):
@@ -332,6 +337,21 @@ def main():
         type=float,
     )
     parser.add_argument("--seed", type=int, default=123, help="Random seed.")
+    parser.add_argument(
+        "--fewshot-as-multiturn",
+        action="store_true",
+        help="Whether to provide the fewshot examples as a multiturn "
+        "conversation or a single user turn.",
+        default=False,
+    )
+    parser.add_argument(
+        "--apply-chat-template",
+        action=argparse.BooleanOptionalAction,
+        help="Specifies whether to apply a chat template to the prompt. If "
+        "the model has a chat template, this defaults to `True`, "
+        "otherwise `False`.",
+        default=None,
+    )
     args = parser.parse_args()

     output_dir = Path(args.output_dir)
@@ -342,18 +362,23 @@ def main():

     mx.random.seed(args.seed)

-    lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens)
-
+    lm = MLXLM(
+        args.model,
+        batch_size=args.batch_size,
+        max_tokens=args.max_tokens,
+        use_chat_template=args.apply_chat_template,
+    )
     results = lm_eval.simple_evaluate(
         model=lm,
         tasks=args.tasks,
+        fewshot_as_multiturn=args.fewshot_as_multiturn,
+        apply_chat_template=lm.use_chat_template,
         num_fewshot=args.num_shots,
         limit=args.limit,
         random_seed=args.seed,
         numpy_random_seed=args.seed,
         torch_random_seed=args.seed,
         fewshot_random_seed=args.seed,
-        apply_chat_template=True,
     )

     model_name = args.model.replace("/", "_")
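A side note on the two HFLM attributes added in the hunks above: assigning another class's function as a plain class attribute is ordinary Python, and that is all these "HF overrides" amount to, since the borrowed function binds to the MLXLM instance at call time. A toy illustration with hypothetical names:

class Donor:
    def describe(self):
        # `self` is whatever instance the attribute was looked up on
        return f"described via {type(self).__name__}"

class Recipient:
    # Borrow Donor's function wholesale, the way MLXLM borrows
    # lm_eval.models.huggingface.HFLM.apply_chat_template above.
    describe = Donor.describe

print(Recipient().describe())  # -> "described via Recipient"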
@@ -27,8 +27,8 @@ setup(
     packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
     python_requires=">=3.8",
     extras_require={
-        "testing": ["datasets"],
-        "evaluation": ["lm-eval"],
+        "test": ["datasets"],
+        "evaluate": ["lm-eval", "tqdm"],
     },
     entry_points={
         "console_scripts": [
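Finally, the "don't add a duplicate BOS token" piece: chat templates usually render the BOS token themselves, so encoding an already-templated prompt with special tokens enabled would prepend a second BOS, which is why `_tokenize` passes `add_special_tokens=not self.use_chat_template`. A minimal sketch with a Hugging Face tokenizer; the model name is an illustrative placeholder and assumes its template emits BOS:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # placeholder model
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
dup = tok.encode(prompt)                           # encode() adds BOS on top of the template's BOS
ok = tok.encode(prompt, add_special_tokens=False)  # keeps only the BOS the template rendered
print(dup[:2], ok[:1])  # dup starts with two BOS ids, ok with one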