Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 04:25:06 +08:00)
Merge branch 'ml-explore:main' into add_modelscope
Commit 5c472e7721
@@ -32,7 +32,7 @@ jobs:
             pip install --upgrade pip
             pip install unittest-xml-reporting
             cd llms/
-            pip install -e ".[testing]"
+            pip install -e ".[test]"
       - run:
           name: Run Python tests
           command: |
@@ -77,15 +77,19 @@ class MLXLM(LM):
         path_or_hf_repo: str,
         batch_size: int = 16,
         max_tokens: Optional[int] = None,
+        use_chat_template: Optional[bool] = None,
     ) -> None:
         super().__init__()
         self._batch_size = batch_size
-        self._model, self._tokenizer = load(path_or_hf_repo)
-        self._max_tokens = max_tokens or self._tokenizer.model_max_length
+        self._model, self.tokenizer = load(path_or_hf_repo)
+        self._max_tokens = max_tokens or self.tokenizer.model_max_length
+        self.use_chat_template = use_chat_template or (
+            self.tokenizer.chat_template is not None
+        )

     def _score_fn(self, inputs, tokenize=True, step_size=32):
         if tokenize:
-            inputs = self._tokenizer.encode(inputs)
+            inputs = self._tokenize(inputs)
         inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
         inputs = mx.array(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]
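The constructor gains a use_chat_template flag whose None default falls back to whether the loaded tokenizer ships a chat template. A minimal standalone sketch of that resolution logic (the helper name and test cases are illustrative, not part of this commit):

from typing import Optional

def resolve_use_chat_template(flag: Optional[bool], chat_template: Optional[str]) -> bool:
    # Mirrors `use_chat_template or (self.tokenizer.chat_template is not None)`:
    # an explicit True wins; a None flag defers to whether a template exists.
    return bool(flag or (chat_template is not None))

assert resolve_use_chat_template(None, "{% for m in messages %}{{ m }}{% endfor %}")  # template present
assert not resolve_use_chat_template(None, None)  # no template, plain prompts
assert resolve_use_chat_template(True, None)      # forced on explicitly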
@@ -149,7 +153,12 @@ class MLXLM(LM):
         return results

     def _tokenize(self, texts):
-        return [tuple(self._tokenizer.encode(t)) for t in texts]
+        return [
+            tuple(
+                self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
+            )
+            for t in texts
+        ]

     def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
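_tokenize now disables add_special_tokens whenever a chat template is in use, since the rendered template already carries its own special tokens and re-adding them can duplicate the BOS. An illustrative sketch of that effect with the Hugging Face tokenizer API (the model repo is a placeholder for any instruct model that has a chat template):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")  # placeholder repo
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hi"}], add_generation_prompt=True, tokenize=False
)
# Encoding the templated string with add_special_tokens=True may prepend an
# extra BOS on some tokenizers; passing False avoids that duplication.
print(tok.encode(prompt, add_special_tokens=True))
print(tok.encode(prompt, add_special_tokens=False))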
@@ -221,6 +230,9 @@ class MLXLM(LM):
         )
         return [(r[0], r[1] == r[2]) for r in results]

+    tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
+    apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
+
     def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
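tokenizer_name and apply_chat_template are borrowed straight from lm_eval's HFLM by assigning them as class attributes, which is consistent with renaming self._tokenizer to self.tokenizer elsewhere in the diff (HFLM's implementations read self.tokenizer). A standalone sketch of the underlying Python pattern (class names are illustrative):

class Base:
    def describe(self):
        # Ordinary function stored on the class; `self` binds at call time.
        return f"tokenizer of {type(self).__name__}"

class Borrower:
    describe = Base.describe  # reuse Base's method without inheriting from it

print(Borrower().describe())  # -> "tokenizer of Borrower"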
@@ -283,21 +295,14 @@ class MLXLM(LM):
         completions = []

         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            if self._tokenizer.chat_template is not None:
-                messages = [{"role": "user", "content": context}]
-                context = self._tokenizer.apply_chat_template(
-                    messages, add_generation_prompt=True
-                )
-            else:
-                context = self._tokenizer.encode(context)
-
+            context = self._tokenize(context)
             max_tokens = min(
                 self._max_tokens,
-                self._tokenizer.model_max_length - len(context),
+                self.tokenizer.model_max_length - len(context),
             )
             text = ""
             for response in stream_generate(
-                self._model, self._tokenizer, prompt=context, max_tokens=max_tokens
+                self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
             ):
                 text += response.text
                 if any(u in text for u in until):
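With the chat-template handling moved into _tokenize, the generation loop just streams tokens and stops on the task's stop strings. A standalone sketch of that loop against the public mlx_lm API (the model repo, prompt, and stop strings are placeholders):

from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # placeholder repo
until = ["\n\n"]  # per-request stop strings, as in generate_until
text = ""
for response in stream_generate(model, tokenizer, prompt="Q: 2+2=\nA:", max_tokens=64):
    text += response.text
    if any(u in text for u in until):  # cut off once any stop string appears
        break
print(text)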
@@ -332,6 +337,21 @@ def main():
         type=float,
     )
     parser.add_argument("--seed", type=int, default=123, help="Random seed.")
+    parser.add_argument(
+        "--fewshot-as-multiturn",
+        action="store_true",
+        help="Whether to provide the fewshot examples as a multiturn "
+        "conversation or a single user turn.",
+        default=False,
+    )
+    parser.add_argument(
+        "--apply-chat-template",
+        action=argparse.BooleanOptionalAction,
+        help="Specifies whether to apply a chat template to the prompt. If "
+        "the model has a chat template, this defaults to `True`, "
+        "otherwise `False`.",
+        default=None,
+    )
     args = parser.parse_args()

     output_dir = Path(args.output_dir)
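--apply-chat-template uses argparse.BooleanOptionalAction, which also generates a --no-apply-chat-template form; the None default leaves the decision to the MLXLM constructor, matching the help text above. A standalone sketch of the parsing behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--apply-chat-template", action=argparse.BooleanOptionalAction, default=None
)
print(parser.parse_args([]).apply_chat_template)                            # None
print(parser.parse_args(["--apply-chat-template"]).apply_chat_template)     # True
print(parser.parse_args(["--no-apply-chat-template"]).apply_chat_template)  # False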
@@ -342,18 +362,23 @@ def main():

     mx.random.seed(args.seed)

-    lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens)
+    lm = MLXLM(
+        args.model,
+        batch_size=args.batch_size,
+        max_tokens=args.max_tokens,
+        use_chat_template=args.apply_chat_template,
+    )
     results = lm_eval.simple_evaluate(
         model=lm,
         tasks=args.tasks,
+        fewshot_as_multiturn=args.fewshot_as_multiturn,
+        apply_chat_template=lm.use_chat_template,
         num_fewshot=args.num_shots,
         limit=args.limit,
         random_seed=args.seed,
         numpy_random_seed=args.seed,
         torch_random_seed=args.seed,
         fewshot_random_seed=args.seed,
-        apply_chat_template=True,
     )

     model_name = args.model.replace("/", "_")
@@ -43,10 +43,11 @@ def setup_arg_parser():
         help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
-        "--eos-token",
+        "--extra-eos-token",
         type=str,
-        default=None,
-        help="End of sequence token for tokenizer",
+        default=(),
+        nargs="+",
+        help="Add tokens in the list of eos tokens that stop generation.",
     )
     parser.add_argument(
         "--system-prompt",
@@ -161,8 +162,6 @@ def main():
         {} if not using_cache else json.loads(metadata["tokenizer_config"])
     )
     tokenizer_config["trust_remote_code"] = True
-    if args.eos_token is not None:
-        tokenizer_config["eos_token"] = args.eos_token

     model_path = args.model
     if using_cache:
@@ -181,6 +180,8 @@ def main():
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
     )
+    for eos_token in args.extra_eos_token:
+        tokenizer.add_eos_token(eos_token)

     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
@@ -266,6 +266,18 @@ class TokenizerWrapper:
             else {tokenizer.eos_token_id}
         )

+    def add_eos_token(self, token: str):
+        token_id = None
+        try:
+            token_id = int(token)
+        except ValueError:
+            token_id = self._tokenizer.convert_tokens_to_ids(token)
+
+        if token_id is None:
+            raise ValueError(f"'{token}' is not a token for this tokenizer")
+
+        self._eos_token_ids.add(token_id)
+
     def __getattr__(self, attr):
         if attr == "detokenizer":
             return self._detokenizer
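add_eos_token accepts either a literal token string (resolved through convert_tokens_to_ids) or a numeric id, and adds the result to the wrapper's set of stop ids; the --extra-eos-token loop above feeds it one entry per CLI value. A brief usage sketch (the model repo and token are placeholders, assuming load returns this TokenizerWrapper as it does in mlx_lm):

from mlx_lm import load

model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # placeholder repo
tokenizer.add_eos_token("<|im_end|>")  # literal token, looked up with convert_tokens_to_ids
tokenizer.add_eos_token("2")           # numeric strings are used directly as token ids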
@@ -697,12 +697,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):

     api = HfApi()
     api.create_repo(repo_id=upload_repo, exist_ok=True)
-    api.upload_folder(
+    api.upload_large_folder(
         folder_path=path,
         repo_id=upload_repo,
         repo_type="model",
-        multi_commits=True,
-        multi_commits_verbose=True,
     )
     print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
@@ -27,8 +27,8 @@ setup(
     packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
     python_requires=">=3.8",
     extras_require={
-        "testing": ["datasets"],
-        "evaluation": ["lm-eval"],
+        "test": ["datasets"],
+        "evaluate": ["lm-eval", "tqdm"],
     },
     entry_points={
         "console_scripts": [