mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-19 09:31:13 +08:00
fix(mlx-lm): type hints in gguf.py (#621)
This commit is contained in:
parent
0ab01b4626
commit
297a908e3d
@ -1,7 +1,7 @@
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from transformers import AutoTokenizer
|
||||
@ -23,7 +23,9 @@ class GGMLFileType(IntEnum):
|
||||
# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
|
||||
class HfVocab:
|
||||
def __init__(
|
||||
self, fname_tokenizer: Path, fname_added_tokens: Path | None = None
|
||||
self,
|
||||
fname_tokenizer: Path,
|
||||
fname_added_tokens: Optional[Union[Path, None]] = None,
|
||||
) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
fname_tokenizer,
|
||||
@ -50,7 +52,7 @@ class HfVocab:
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
self.fname_added_tokens = fname_added_tokens
|
||||
|
||||
def hf_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
|
||||
def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
reverse_vocab = {
|
||||
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
|
||||
}
|
||||
@ -63,7 +65,7 @@ class HfVocab:
|
||||
)
|
||||
|
||||
def get_token_type(
|
||||
self, token_id: int, token_text: bytes, special_ids: set[int]
|
||||
self, token_id: int, token_text: bytes, special_ids: Set[int]
|
||||
) -> TokenType:
|
||||
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||
return TokenType.BYTE
|
||||
@ -72,7 +74,7 @@ class HfVocab:
|
||||
def get_token_score(self, token_id: int) -> float:
|
||||
return -1000.0
|
||||
|
||||
def added_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
|
||||
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
for text in self.added_tokens_list:
|
||||
if text in self.specials:
|
||||
toktype = self.get_token_type(
|
||||
@ -87,7 +89,7 @@ class HfVocab:
|
||||
def has_newline_token(self):
|
||||
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
|
||||
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
|
||||
def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
yield from self.hf_tokens()
|
||||
yield from self.added_tokens()
|
||||
|
||||
|
58
llms/tests/test_gguf.py
Normal file
58
llms/tests/test_gguf.py
Normal file
@ -0,0 +1,58 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.gguf import convert_to_gguf
|
||||
|
||||
|
||||
class TestConvertToGGUFWithoutMocks(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json")
|
||||
with open(cls.tokenizer_file_path, "w") as f:
|
||||
f.write("{}")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
@patch("transformers.AutoTokenizer.from_pretrained")
|
||||
@patch("mlx.core.save_gguf")
|
||||
def test_convert_to_gguf(
|
||||
self,
|
||||
mock_save_gguf,
|
||||
mock_from_pretrained,
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.vocab_size = 3
|
||||
mock_tokenizer.get_added_vocab.return_value = {}
|
||||
mock_tokenizer.get_vocab.return_value = {"<pad>": 0, "hello": 1, "world": 2}
|
||||
mock_tokenizer.all_special_tokens = ["<pad>"]
|
||||
mock_tokenizer.all_special_ids = [0]
|
||||
mock_from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
model_path = Path(self.test_dir)
|
||||
weights = {
|
||||
"self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]),
|
||||
}
|
||||
config = {
|
||||
"num_attention_heads": 1,
|
||||
"num_hidden_layers": 1,
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"_name_or_path": "test-llama",
|
||||
}
|
||||
output_file_path = "/fake/output/path/gguf_model.gguf"
|
||||
|
||||
convert_to_gguf(model_path, weights, config, output_file_path)
|
||||
called_args, _ = mock_save_gguf.call_args
|
||||
self.assertEqual(called_args[0], output_file_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -5,7 +5,7 @@ import mlx.core as mx
|
||||
from mlx_lm.sample_utils import top_p_sampling
|
||||
|
||||
|
||||
class TestLora(unittest.TestCase):
|
||||
class TestSamplingUtils(unittest.TestCase):
|
||||
@patch("mlx.core.random.categorical")
|
||||
def test_top_p_sampling(self, mock_categorical):
|
||||
logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
|
||||
|
Loading…
Reference in New Issue
Block a user