Mirror of https://github.com/ml-explore/mlx-examples.git
fix(lora): tokenizer return incompatible mx array (#271)
* fix(lora): tokenizer returned an incompatible mx array encoding
* add readme nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
parent 7b258f33ac
commit 7cfda327fd
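For context on the incompatibility named in the title: the custom `Tokenizer.encode` removed below returned an `mx.array` per sequence, while the batching code pads sequences with NumPy and works most naturally with plain token-id lists. A minimal sketch of the new behavior, assuming the `transformers` package and an illustrative model id:

```python
# Sketch only: AutoTokenizer.encode returns a plain Python list of ints,
# which NumPy can assign into a padded buffer directly.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative repo
ids = tok.encode("hello world")
print(type(ids))  # <class 'list'>
```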
````diff
@@ -162,7 +162,7 @@ useful for the sake of attribution and model versioning.
 For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
 
 ```
-python fuse.py --upload My-4-bit-model --hf-repo mistralai/Mistral-7B-v0.1
+python fuse.py --upload-name My-4-bit-model --hf-repo mistralai/Mistral-7B-v0.1
 ```
 
 ## Custom Data
````
```diff
@@ -172,10 +172,7 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
         # Collect batches from dataset
         for i in range(0, len(indices) - batch_size + 1, batch_size):
             # Encode batch
-            batch = [
-                tokenizer.encode(dset[indices[i + j]], eos=True)
-                for j in range(batch_size)
-            ]
+            batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
             lengths = [len(x) for x in batch]
 
             # Check if any sequence is longer than 2048 tokens
```
```diff
@@ -187,6 +184,7 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
             # Pad to the max length
             batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
 
             for j in range(batch_size):
                 batch_arr[j, : lengths[j]] = batch[j]
             batch = mx.array(batch_arr)
```
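Taken together, the two hunks above reduce to the following standalone batching sketch; the texts and batch size are made up, and the over-2048-tokens check is omitted:

```python
# A minimal sketch of the padded-batch construction in iterate_batches,
# assuming a Hugging Face tokenizer (texts are illustrative).
import mlx.core as mx
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
texts = ["hello world", "a somewhat longer example sentence"]

batch = [tokenizer.encode(t) for t in texts]  # plain lists of token ids
lengths = [len(x) for x in batch]

# Zero-pad every sequence to the batch maximum, then convert to MLX once.
batch_arr = np.zeros((len(batch), max(lengths)), np.int32)
for j in range(len(batch)):
    batch_arr[j, : lengths[j]] = batch[j]
batch = mx.array(batch_arr)
```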
```diff
@@ -52,32 +52,6 @@ class ModelArgs:
         )
 
 
-class Tokenizer:
-    def __init__(self, model_path: str):
-        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
-        self._eos = self._tokenizer.eos_token_id
-        self._bos = self._tokenizer.bos_token_id
-
-    def encode(self, s: str, eos: bool = False) -> mx.array:
-        toks = self._tokenizer(
-            s,
-            return_tensors="np",
-            return_attention_mask=False,
-        )[
-            "input_ids"
-        ][0]
-        if eos:
-            toks = np.concatenate([toks, [self._eos]])
-        return mx.array(toks)
-
-    @property
-    def eos_id(self) -> int:
-        return self._eos
-
-    def decode(self, t: List[int]) -> str:
-        return self._tokenizer.decode(t)
-
-
 class LoRALinear(nn.Module):
     @staticmethod
     def from_linear(linear: nn.Linear, rank: int = 8):
```
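The removed wrapper mostly forwarded to `AutoTokenizer`, so each of its members has a direct equivalent on the Hugging Face object. A hedged mapping, using only standard `transformers` attributes (the model id is illustrative):

```python
# Rough equivalents of the deleted Tokenizer wrapper (sketch, not from the diff).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

ids = tok.encode("some text")  # replaces Tokenizer.encode (returns list[int])
eos = tok.eos_token_id         # replaces the eos_id property
text = tok.decode(ids)         # replaces Tokenizer.decode
```

Note that plain `encode()` does not append an EOS token the way the old `eos=True` path did; callers that need a trailing EOS must append `tok.eos_token_id` themselves.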
```diff
@@ -359,7 +333,8 @@ def load(path_or_hf_repo: str):
     model.load_weights(list(weights.items()))
 
     mx.eval(model.parameters())
-    return model, Tokenizer(model_path), config
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    return model, tokenizer, config
 
 
 def generate(prompt: mx.array, model: Model, temp: float = 0.0):
```
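With this change, callers of `load` receive the Hugging Face tokenizer directly. A hypothetical usage sketch (path and prompt illustrative), wrapping the encoded list in `mx.array` since `generate` takes an `mx.array` prompt:

```python
# Sketch: consuming the new (model, tokenizer, config) return value.
import mlx.core as mx

model, tokenizer, config = load("mlx_model")  # illustrative local path
prompt = mx.array(tokenizer.encode("Tell me a joke."))
# generate(prompt, model, temp=0.7) would then produce tokens from the model
```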