From 297a908e3db205a5d35acf806a30addbb2515788 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 27 Mar 2024 01:56:01 +1100
Subject: [PATCH] fix(mlx-lm): type hints in gguf.py (#621)

---
 llms/mlx_lm/gguf.py             | 14 ++++----
 llms/tests/test_gguf.py         | 58 +++++++++++++++++++++++++++++++++
 llms/tests/test_sample_utils.py |  2 +-
 3 files changed, 67 insertions(+), 7 deletions(-)
 create mode 100644 llms/tests/test_gguf.py

diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py
index 382e1dce..1f858d70 100644
--- a/llms/mlx_lm/gguf.py
+++ b/llms/mlx_lm/gguf.py
@@ -1,7 +1,7 @@
 import re
 from enum import IntEnum
 from pathlib import Path
-from typing import Iterable, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
 import mlx.core as mx
 from transformers import AutoTokenizer
@@ -23,7 +23,9 @@ class GGMLFileType(IntEnum):
 # copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
 class HfVocab:
     def __init__(
-        self, fname_tokenizer: Path, fname_added_tokens: Path | None = None
+        self,
+        fname_tokenizer: Path,
+        fname_added_tokens: Optional[Union[Path, None]] = None,
     ) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(
             fname_tokenizer,
@@ -50,7 +52,7 @@ class HfVocab:
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
 
-    def hf_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         reverse_vocab = {
             id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
         }
@@ -63,7 +65,7 @@ class HfVocab:
             )
 
     def get_token_type(
-        self, token_id: int, token_text: bytes, special_ids: set[int]
+        self, token_id: int, token_text: bytes, special_ids: Set[int]
     ) -> TokenType:
         if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
             return TokenType.BYTE
@@ -72,7 +74,7 @@ class HfVocab:
     def get_token_score(self, token_id: int) -> float:
         return -1000.0
 
-    def added_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         for text in self.added_tokens_list:
             if text in self.specials:
                 toktype = self.get_token_type(
@@ -87,7 +89,7 @@ class HfVocab:
     def has_newline_token(self):
         return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
 
-    def all_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+    def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         yield from self.hf_tokens()
         yield from self.added_tokens()
 
diff --git a/llms/tests/test_gguf.py b/llms/tests/test_gguf.py
new file mode 100644
index 00000000..24ca64aa
--- /dev/null
+++ b/llms/tests/test_gguf.py
@@ -0,0 +1,58 @@
+import os
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import mlx.core as mx
+from mlx_lm.gguf import convert_to_gguf
+
+
+class TestConvertToGGUFWithoutMocks(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.test_dir_fid = tempfile.TemporaryDirectory()
+        cls.test_dir = cls.test_dir_fid.name
+        cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json")
+        with open(cls.tokenizer_file_path, "w") as f:
+            f.write("{}")
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.test_dir_fid.cleanup()
+
+    @patch("transformers.AutoTokenizer.from_pretrained")
+    @patch("mlx.core.save_gguf")
+    def test_convert_to_gguf(
+        self,
+        mock_save_gguf,
+        mock_from_pretrained,
+    ):
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.vocab_size = 3
+        mock_tokenizer.get_added_vocab.return_value = {}
+        mock_tokenizer.get_vocab.return_value = {"": 0, "hello": 1, "world": 2}
+        mock_tokenizer.all_special_tokens = [""]
+        mock_tokenizer.all_special_ids = [0]
+        mock_from_pretrained.return_value = mock_tokenizer
+
+        model_path = Path(self.test_dir)
+        weights = {
+            "self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]),
+        }
+        config = {
+            "num_attention_heads": 1,
+            "num_hidden_layers": 1,
+            "hidden_size": 768,
+            "intermediate_size": 3072,
+            "_name_or_path": "test-llama",
+        }
+        output_file_path = "/fake/output/path/gguf_model.gguf"
+
+        convert_to_gguf(model_path, weights, config, output_file_path)
+        called_args, _ = mock_save_gguf.call_args
+        self.assertEqual(called_args[0], output_file_path)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index 8b960736..0bccdd07 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 from mlx_lm.sample_utils import top_p_sampling
 
 
-class TestLora(unittest.TestCase):
+class TestSamplingUtils(unittest.TestCase):
     @patch("mlx.core.random.categorical")
     def test_top_p_sampling(self, mock_categorical):
         logits = mx.array([[1.0, 2.0, 3.0, 4.0]])