mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
fix(mlx-lm): type hints in gguf.py (#621)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from transformers import AutoTokenizer
|
||||
@@ -23,7 +23,9 @@ class GGMLFileType(IntEnum):
|
||||
# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
|
||||
class HfVocab:
|
||||
def __init__(
|
||||
self, fname_tokenizer: Path, fname_added_tokens: Path | None = None
|
||||
self,
|
||||
fname_tokenizer: Path,
|
||||
fname_added_tokens: Optional[Union[Path, None]] = None,
|
||||
) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
fname_tokenizer,
|
||||
@@ -50,7 +52,7 @@ class HfVocab:
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
self.fname_added_tokens = fname_added_tokens
|
||||
|
||||
def hf_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
|
||||
def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
reverse_vocab = {
|
||||
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
|
||||
}
|
||||
@@ -63,7 +65,7 @@ class HfVocab:
|
||||
)
|
||||
|
||||
def get_token_type(
|
||||
self, token_id: int, token_text: bytes, special_ids: set[int]
|
||||
self, token_id: int, token_text: bytes, special_ids: Set[int]
|
||||
) -> TokenType:
|
||||
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||
return TokenType.BYTE
|
||||
@@ -72,7 +74,7 @@ class HfVocab:
|
||||
def get_token_score(self, token_id: int) -> float:
|
||||
return -1000.0
|
||||
|
||||
def added_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
|
||||
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
for text in self.added_tokens_list:
|
||||
if text in self.specials:
|
||||
toktype = self.get_token_type(
|
||||
@@ -87,7 +89,7 @@ class HfVocab:
|
||||
def has_newline_token(self):
|
||||
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
|
||||
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, TokenType]]:
|
||||
def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
yield from self.hf_tokens()
|
||||
yield from self.added_tokens()
|
||||
|
||||
|
Reference in New Issue
Block a user