fix(mlx-lm): type hints in gguf.py (#621)

Anchen 2024-03-27 01:56:01 +11:00 committed by GitHub
parent 0ab01b4626
commit 297a908e3d
3 changed files with 67 additions and 7 deletions

llms/mlx_lm/gguf.py

@@ -1,7 +1,7 @@
 import re
 from enum import IntEnum
 from pathlib import Path
-from typing import Iterable, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import mlx.core as mx
 from transformers import AutoTokenizer
@@ -23,7 +23,9 @@ class GGMLFileType(IntEnum):
 # copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
 class HfVocab:
     def __init__(
-        self, fname_tokenizer: Path, fname_added_tokens: Path | None = None
+        self,
+        fname_tokenizer: Path,
+        fname_added_tokens: Optional[Union[Path, None]] = None,
     ) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(
             fname_tokenizer,
@@ -50,7 +52,7 @@ class HfVocab:
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens

-    def hf_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         reverse_vocab = {
             id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
         }
@@ -63,7 +65,7 @@ class HfVocab:
         )

     def get_token_type(
-        self, token_id: int, token_text: bytes, special_ids: set[int]
+        self, token_id: int, token_text: bytes, special_ids: Set[int]
     ) -> TokenType:
         if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
             return TokenType.BYTE
@@ -72,7 +74,7 @@ class HfVocab:
     def get_token_score(self, token_id: int) -> float:
         return -1000.0

-    def added_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         for text in self.added_tokens_list:
             if text in self.specials:
                 toktype = self.get_token_type(
@@ -87,7 +89,7 @@ class HfVocab:
     def has_newline_token(self):
         return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

-    def all_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         yield from self.hf_tokens()
         yield from self.added_tokens()
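
Aside (not part of the diff): these spellings matter because Python evaluates annotations when the def executes. PEP 604 unions (`Path | None`) require Python 3.10 and builtin generics (`tuple[...]`, `set[...]`) require 3.9, so on 3.8 the module would raise TypeError before it finished importing. A minimal self-contained sketch of the failure mode and the backported spellings used above:

# Runs on Python 3.8+; the 3.10-only spellings are guarded so this
# module imports on older interpreters.
import sys
from pathlib import Path
from typing import Iterable, Optional, Tuple

if sys.version_info >= (3, 10):
    # PEP 604 union and builtin generic: fine on 3.10+, but evaluating
    # these annotation expressions on 3.8 raises TypeError.
    def tokens(p: Path | None = None) -> Iterable[tuple[bytes, float]]:
        return []
else:
    # The typing-module equivalents work on every supported version.
    def tokens(p: Optional[Path] = None) -> Iterable[Tuple[bytes, float]]:
        return []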

llms/tests/test_gguf.py (new file, 58 lines)

@@ -0,0 +1,58 @@
+import os
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import mlx.core as mx
+
+from mlx_lm.gguf import convert_to_gguf
+
+
+class TestConvertToGGUFWithoutMocks(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.test_dir_fid = tempfile.TemporaryDirectory()
+        cls.test_dir = cls.test_dir_fid.name
+        cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json")
+        with open(cls.tokenizer_file_path, "w") as f:
+            f.write("{}")
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.test_dir_fid.cleanup()
+
+    @patch("transformers.AutoTokenizer.from_pretrained")
+    @patch("mlx.core.save_gguf")
+    def test_convert_to_gguf(
+        self,
+        mock_save_gguf,
+        mock_from_pretrained,
+    ):
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.vocab_size = 3
+        mock_tokenizer.get_added_vocab.return_value = {}
+        mock_tokenizer.get_vocab.return_value = {"<pad>": 0, "hello": 1, "world": 2}
+        mock_tokenizer.all_special_tokens = ["<pad>"]
+        mock_tokenizer.all_special_ids = [0]
+        mock_from_pretrained.return_value = mock_tokenizer
+
+        model_path = Path(self.test_dir)
+        weights = {
+            "self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]),
+        }
+        config = {
+            "num_attention_heads": 1,
+            "num_hidden_layers": 1,
+            "hidden_size": 768,
+            "intermediate_size": 3072,
+            "_name_or_path": "test-llama",
+        }
+        output_file_path = "/fake/output/path/gguf_model.gguf"
+        convert_to_gguf(model_path, weights, config, output_file_path)
+        called_args, _ = mock_save_gguf.call_args
+        self.assertEqual(called_args[0], output_file_path)
+
+
+if __name__ == "__main__":
+    unittest.main()
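
One subtlety in the test above (an aside, not part of the commit): stacked @patch decorators are applied bottom-up, which is why mock_save_gguf (from the inner mlx.core.save_gguf patch) is the first extra argument and mock_from_pretrained the second. A minimal self-contained sketch of the same pattern:

import unittest
from unittest.mock import patch


class TestPatchOrder(unittest.TestCase):
    @patch("os.path.exists")  # outer patch -> second extra argument
    @patch("os.path.join")  # inner patch -> first extra argument
    def test_order(self, mock_join, mock_exists):
        import os

        mock_join.return_value = "joined"
        mock_exists.return_value = True
        self.assertEqual(os.path.join("a", "b"), "joined")
        self.assertTrue(os.path.exists("anything"))


if __name__ == "__main__":
    unittest.main()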

llms/tests/test_sample_utils.py

@@ -5,7 +5,7 @@ import mlx.core as mx
 from mlx_lm.sample_utils import top_p_sampling


-class TestLora(unittest.TestCase):
+class TestSamplingUtils(unittest.TestCase):
     @patch("mlx.core.random.categorical")
     def test_top_p_sampling(self, mock_categorical):
         logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
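
For context on what the renamed test exercises, here is a minimal sketch of nucleus (top-p) sampling written against plain mlx. It is illustrative only, not mlx_lm.sample_utils' actual implementation, and the helper name top_p_sketch is hypothetical:

import mlx.core as mx


def top_p_sketch(logits: mx.array, top_p: float) -> mx.array:
    # Illustrative top-p sampler, not mlx_lm's implementation.
    # Convert logits to probabilities and sort them in descending order.
    probs = mx.softmax(logits, axis=-1)
    order = mx.argsort(-probs, axis=-1)
    sorted_probs = mx.take_along_axis(probs, order, axis=-1)
    # Keep the smallest prefix of tokens whose cumulative mass reaches top_p.
    cumulative = mx.cumsum(sorted_probs, axis=-1)
    keep = (cumulative - sorted_probs) < top_p
    filtered = mx.where(keep, sorted_probs, mx.zeros_like(sorted_probs))
    # Sample in sorted space, then map back to vocabulary indices.
    idx = mx.random.categorical(mx.log(filtered + 1e-12))
    return mx.take_along_axis(order, mx.expand_dims(idx, -1), axis=-1)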