mirror of https://github.com/ml-explore/mlx-examples.git
fix(mlx-lm): type hints in gguf.py (#621)

commit 297a908e3d, parent 0ab01b4626
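The gguf.py diff replaces built-in generic annotations (tuple[...], set[...]; PEP 585, Python 3.9+) and the "Path | None" union syntax (PEP 604, Python 3.10+) with their typing-module equivalents. Besides those annotation changes, the commit adds llms/tests/test_gguf.py and renames a test class in the sampling-utils tests. A minimal sketch of the failure being avoided, assuming the goal is to keep the module importable on older Python (the commit message does not name a target version):

    # Python 3.8 evaluates annotations at function-definition time, so
    # built-in generics raise immediately:
    #     def f(ids: set[int]) -> tuple[bytes, float]: ...
    #     TypeError: 'type' object is not subscriptable
    # The typing-module spellings below work on 3.8 and later.
    from typing import Optional, Set, Tuple

    def f(ids: Set[int]) -> Tuple[bytes, float]:
        return b"", 0.0

    def g(path: Optional[str] = None) -> None:
        pass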
llms/mlx_lm/gguf.py
@@ -1,7 +1,7 @@
 import re
 from enum import IntEnum
 from pathlib import Path
-from typing import Iterable, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import mlx.core as mx
 from transformers import AutoTokenizer
@@ -23,7 +23,9 @@ class GGMLFileType(IntEnum):
 # copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
 class HfVocab:
     def __init__(
-        self, fname_tokenizer: Path, fname_added_tokens: Path | None = None
+        self,
+        fname_tokenizer: Path,
+        fname_added_tokens: Optional[Union[Path, None]] = None,
     ) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(
             fname_tokenizer,
@@ -50,7 +52,7 @@ class HfVocab:
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens

-    def hf_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         reverse_vocab = {
             id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
         }
@@ -63,7 +65,7 @@ class HfVocab:
         )

     def get_token_type(
-        self, token_id: int, token_text: bytes, special_ids: set[int]
+        self, token_id: int, token_text: bytes, special_ids: Set[int]
     ) -> TokenType:
         if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
             return TokenType.BYTE
@@ -72,7 +74,7 @@ class HfVocab:
     def get_token_score(self, token_id: int) -> float:
         return -1000.0

-    def added_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         for text in self.added_tokens_list:
             if text in self.specials:
                 toktype = self.get_token_type(
@@ -87,7 +89,7 @@ class HfVocab:
     def has_newline_token(self):
         return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

-    def all_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         yield from self.hf_tokens()
         yield from self.added_tokens()
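One detail worth noting in the new __init__ signature: Optional[Union[Path, None]] is redundant but harmless, since typing flattens nested unions and deduplicates None, so it evaluates to the same type as the shorter Optional[Path]. A quick check, illustrative and not part of the commit:

    from pathlib import Path
    from typing import Optional, Union

    # typing collapses Union[Path, None] and the outer Optional into one union
    assert Optional[Union[Path, None]] == Optional[Path]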
llms/tests/test_gguf.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import os
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import mlx.core as mx
+from mlx_lm.gguf import convert_to_gguf
+
+
+class TestConvertToGGUFWithoutMocks(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.test_dir_fid = tempfile.TemporaryDirectory()
+        cls.test_dir = cls.test_dir_fid.name
+        cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json")
+        with open(cls.tokenizer_file_path, "w") as f:
+            f.write("{}")
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.test_dir_fid.cleanup()
+
+    @patch("transformers.AutoTokenizer.from_pretrained")
+    @patch("mlx.core.save_gguf")
+    def test_convert_to_gguf(
+        self,
+        mock_save_gguf,
+        mock_from_pretrained,
+    ):
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.vocab_size = 3
+        mock_tokenizer.get_added_vocab.return_value = {}
+        mock_tokenizer.get_vocab.return_value = {"<pad>": 0, "hello": 1, "world": 2}
+        mock_tokenizer.all_special_tokens = ["<pad>"]
+        mock_tokenizer.all_special_ids = [0]
+        mock_from_pretrained.return_value = mock_tokenizer
+
+        model_path = Path(self.test_dir)
+        weights = {
+            "self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]),
+        }
+        config = {
+            "num_attention_heads": 1,
+            "num_hidden_layers": 1,
+            "hidden_size": 768,
+            "intermediate_size": 3072,
+            "_name_or_path": "test-llama",
+        }
+        output_file_path = "/fake/output/path/gguf_model.gguf"
+
+        convert_to_gguf(model_path, weights, config, output_file_path)
+        called_args, _ = mock_save_gguf.call_args
+        self.assertEqual(called_args[0], output_file_path)
+
+
+if __name__ == "__main__":
+    unittest.main()
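A note on the mock plumbing above: stacked @patch decorators are applied bottom-up, so the innermost patch supplies the first mock argument, which is why mock_save_gguf precedes mock_from_pretrained in the test signature. A self-contained sketch (the patched targets here are arbitrary stand-ins, not from the commit):

    from unittest.mock import patch

    @patch("os.getcwd")  # outermost patch -> second argument
    @patch("os.getpid")  # innermost patch -> first argument
    def demo(mock_getpid, mock_getcwd):
        # Arguments arrive innermost-first, mirroring bottom-up decoration.
        assert mock_getpid is not mock_getcwd

    demo()

Because mx.save_gguf is patched out, the test exercises convert_to_gguf without writing a real GGUF file and only asserts that the output path reaches the save call. Since the file ends in unittest.main(), it should also run directly (for example, python llms/tests/test_gguf.py), assuming mlx, mlx_lm, and transformers are importable.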
llms/tests/test_sample_utils.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 from mlx_lm.sample_utils import top_p_sampling


-class TestLora(unittest.TestCase):
+class TestSamplingUtils(unittest.TestCase):
     @patch("mlx.core.random.categorical")
     def test_top_p_sampling(self, mock_categorical):
         logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
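The hunk above renames the test class so its name matches the sampling utilities it actually exercises; TestLora appears to be a leftover from copying the LoRA test file.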