From 7cfda327fd128428070813cee299b88e8238054f Mon Sep 17 00:00:00 2001
From: Anchen
Date: Tue, 9 Jan 2024 19:46:38 -0800
Subject: [PATCH] fix(lora): tokenizer returns incompatible mx array (#271)
* fix(lora): tokenizer encode returns incompatible mx array
* add readme nit
---------
Co-authored-by: Awni Hannun
---
lora/README.md | 2 +-
lora/lora.py | 6 ++----
lora/models.py | 29 ++---------------------------
3 files changed, 5 insertions(+), 32 deletions(-)
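In short, the patch drops the custom `Tokenizer` wrapper in `lora/models.py`, whose `encode` returned an `mx.array`, and uses the Hugging Face `AutoTokenizer` directly; its `encode` returns a plain Python list of token ids, which the NumPy padding code in `iterate_batches` can consume. A minimal sketch of the new behavior, with an illustrative checkpoint name:

```python
import numpy as np
import mlx.core as mx
from transformers import AutoTokenizer

# Illustrative checkpoint; any Hugging Face repo with a tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# AutoTokenizer.encode returns a plain list of ints, not an mx.array ...
toks = tokenizer.encode("hello world")

# ... so it can be assigned straight into a padded NumPy batch buffer,
# which is what iterate_batches does before wrapping it in mx.array.
batch_arr = np.zeros((1, len(toks)), np.int32)
batch_arr[0, : len(toks)] = toks
batch = mx.array(batch_arr)
```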
diff --git a/lora/README.md b/lora/README.md
index 7581aced..c4a80341 100644
--- a/lora/README.md
+++ b/lora/README.md
@@ -162,7 +162,7 @@ useful for the sake of attribution and model versioning.
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
```
-python fuse.py --upload My-4-bit-model --hf-repo mistralai/Mistral-7B-v0.1
+python fuse.py --upload-name My-4-bit-model --hf-repo mistralai/Mistral-7B-v0.1
```
## Custom Data
diff --git a/lora/lora.py b/lora/lora.py
index e895a87b..ee809b83 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -172,10 +172,7 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
# Collect batches from dataset
for i in range(0, len(indices) - batch_size + 1, batch_size):
# Encode batch
- batch = [
- tokenizer.encode(dset[indices[i + j]], eos=True)
- for j in range(batch_size)
- ]
+ batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
lengths = [len(x) for x in batch]
# Check if any sequence is longer than 2048 tokens
@@ -187,6 +184,7 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
# Pad to the max length
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
+
for j in range(batch_size):
batch_arr[j, : lengths[j]] = batch[j]
batch = mx.array(batch_arr)
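One behavioral note: the removed wrapper appended an EOS token when called with `eos=True`, and the new call site does not. If a trailing EOS is still wanted per sequence, it can be appended from the tokenizer's own `eos_token_id`; a sketch only, not part of this patch:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative

toks = tokenizer.encode("an example training sequence")
# Mirror the old eos=True behavior explicitly if needed.
if toks[-1] != tokenizer.eos_token_id:
    toks.append(tokenizer.eos_token_id)
```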
diff --git a/lora/models.py b/lora/models.py
index 3b7d4848..244d8f5a 100644
--- a/lora/models.py
+++ b/lora/models.py
@@ -52,32 +52,6 @@ class ModelArgs:
)
-class Tokenizer:
- def __init__(self, model_path: str):
- self._tokenizer = AutoTokenizer.from_pretrained(model_path)
- self._eos = self._tokenizer.eos_token_id
- self._bos = self._tokenizer.bos_token_id
-
- def encode(self, s: str, eos: bool = False) -> mx.array:
- toks = self._tokenizer(
- s,
- return_tensors="np",
- return_attention_mask=False,
- )[
- "input_ids"
- ][0]
- if eos:
- toks = np.concatenate([toks, [self._eos]])
- return mx.array(toks)
-
- @property
- def eos_id(self) -> int:
- return self._eos
-
- def decode(self, t: List[int]) -> str:
- return self._tokenizer.decode(t)
-
-
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
@@ -359,7 +333,8 @@ def load(path_or_hf_repo: str):
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
- return model, Tokenizer(model_path), config
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ return model, tokenizer, config
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
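Since `load` now returns the Hugging Face tokenizer itself, callers that used the old wrapper's helpers can switch to the tokenizer's own API: `mx.array(tokenizer.encode(s))` in place of the wrapper's `encode`, `tokenizer.eos_token_id` in place of `eos_id`, and `tokenizer.decode` unchanged. A hedged sketch with an illustrative checkpoint:

```python
import mlx.core as mx
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative

prompt = mx.array(tokenizer.encode("Q: What is LoRA?\nA:"))  # wrapper encode -> mx.array(...)
eos_id = tokenizer.eos_token_id                              # wrapper eos_id -> eos_token_id
print(tokenizer.decode(prompt.tolist()))                     # decode works the same on a list of ids
```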