From 297a908e3db205a5d35acf806a30addbb2515788 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 27 Mar 2024 01:56:01 +1100
Subject: [PATCH] fix(mlx-lm): type hints in gguf.py (#621)
---
llms/mlx_lm/gguf.py | 14 ++++----
llms/tests/test_gguf.py | 58 +++++++++++++++++++++++++++++++++
llms/tests/test_sample_utils.py | 2 +-
3 files changed, 67 insertions(+), 7 deletions(-)
create mode 100644 llms/tests/test_gguf.py
diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py
index 382e1dce..1f858d70 100644
--- a/llms/mlx_lm/gguf.py
+++ b/llms/mlx_lm/gguf.py
@@ -1,7 +1,7 @@
import re
from enum import IntEnum
from pathlib import Path
-from typing import Iterable, Union
+from typing import Iterable, Optional, Set, Tuple, Union
import mlx.core as mx
from transformers import AutoTokenizer
@@ -23,7 +23,9 @@ class GGMLFileType(IntEnum):
# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
class HfVocab:
def __init__(
- self, fname_tokenizer: Path, fname_added_tokens: Path | None = None
+ self,
+ fname_tokenizer: Path,
+ fname_added_tokens: Optional[Union[Path, None]] = None,
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(
fname_tokenizer,
@@ -50,7 +52,7 @@ class HfVocab:
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
- def hf_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+ def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
reverse_vocab = {
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
}
@@ -63,7 +65,7 @@ class HfVocab:
)
def get_token_type(
- self, token_id: int, token_text: bytes, special_ids: set[int]
+ self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
return TokenType.BYTE
@@ -72,7 +74,7 @@ class HfVocab:
def get_token_score(self, token_id: int) -> float:
return -1000.0
- def added_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+ def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
for text in self.added_tokens_list:
if text in self.specials:
toktype = self.get_token_type(
@@ -87,7 +89,7 @@ class HfVocab:
def has_newline_token(self):
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
- def all_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
+ def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
yield from self.hf_tokens()
yield from self.added_tokens()
diff --git a/llms/tests/test_gguf.py b/llms/tests/test_gguf.py
new file mode 100644
index 00000000..24ca64aa
--- /dev/null
+++ b/llms/tests/test_gguf.py
@@ -0,0 +1,58 @@
+import os
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import mlx.core as mx
+from mlx_lm.gguf import convert_to_gguf
+
+
+class TestConvertToGGUFWithoutMocks(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.test_dir_fid = tempfile.TemporaryDirectory()
+ cls.test_dir = cls.test_dir_fid.name
+ cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json")
+ with open(cls.tokenizer_file_path, "w") as f:
+ f.write("{}")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.test_dir_fid.cleanup()
+
+ @patch("transformers.AutoTokenizer.from_pretrained")
+ @patch("mlx.core.save_gguf")
+ def test_convert_to_gguf(
+ self,
+ mock_save_gguf,
+ mock_from_pretrained,
+ ):
+ mock_tokenizer = MagicMock()
+ mock_tokenizer.vocab_size = 3
+ mock_tokenizer.get_added_vocab.return_value = {}
+ mock_tokenizer.get_vocab.return_value = {"": 0, "hello": 1, "world": 2}
+ mock_tokenizer.all_special_tokens = [""]
+ mock_tokenizer.all_special_ids = [0]
+ mock_from_pretrained.return_value = mock_tokenizer
+
+ model_path = Path(self.test_dir)
+ weights = {
+ "self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]),
+ }
+ config = {
+ "num_attention_heads": 1,
+ "num_hidden_layers": 1,
+ "hidden_size": 768,
+ "intermediate_size": 3072,
+ "_name_or_path": "test-llama",
+ }
+ output_file_path = "/fake/output/path/gguf_model.gguf"
+
+ convert_to_gguf(model_path, weights, config, output_file_path)
+ called_args, _ = mock_save_gguf.call_args
+ self.assertEqual(called_args[0], output_file_path)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index 8b960736..0bccdd07 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -5,7 +5,7 @@ import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling
-class TestLora(unittest.TestCase):
+class TestSamplingUtils(unittest.TestCase):
@patch("mlx.core.random.categorical")
def test_top_p_sampling(self, mock_categorical):
logits = mx.array([[1.0, 2.0, 3.0, 4.0]])