diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index ee026363..b61aecaf 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -200,3 +200,10 @@ class Model(nn.Module):
     ):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
+
+    @staticmethod
+    def sanitize(weights):
+        # Remove unused precomputed rotary freqs
+        return {
+            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+        }
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index f134bb70..692326fe 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -122,7 +122,7 @@ def generate(
 
     tokens = []
     skip = 0
-    REPLACEMENT_CHAR = '\ufffd'
+    REPLACEMENT_CHAR = "\ufffd"
 
     for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
         if token == tokenizer.eos_token_id:
@@ -136,7 +136,7 @@ def generate(
             print(s[skip:], end="", flush=True)
             skip = len(s)
 
-    tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '')
+    tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
     if verbose:
         print(tokens[skip:], flush=True)
     return tokens
@@ -174,6 +174,8 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
         weights.update(mx.load(wf))
 
     model_class, model_args_class = _get_classes(config=config)
+    if hasattr(model_class, "sanitize"):
+        weights = model_class.sanitize(weights)
     model_args = model_args_class.from_dict(config)
     model = model_class(model_args)
 
diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py
index bb4c9442..7a5d5eb2 100644
--- a/stable_diffusion/image2image.py
+++ b/stable_diffusion/image2image.py
@@ -69,5 +69,5 @@ if __name__ == "__main__":
     x = (x * 255).astype(mx.uint8)
 
     # Save them to disc
-    im = Image.fromarray(x.__array__())
+    im = Image.fromarray(np.array(x))
     im.save(args.output)
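
Note: a minimal sketch of how the sanitize hook added above is picked up by load(). The weights dict and its keys are illustrative stand-ins for the arrays loaded from checkpoint shards, not a real checkpoint:

# Illustrative stand-in for the weights dict produced by mx.load().
weights = {
    "model.layers.0.self_attn.q_proj.weight": None,  # placeholder array
    "model.layers.0.self_attn.rotary_emb.inv_freq": None,  # placeholder array
}

class Model:
    @staticmethod
    def sanitize(weights):
        # Same filter as the llama.py change: drop the precomputed rotary
        # frequencies, which are unused by this implementation.
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

# Mirrors the hasattr() guard added to load(): architectures without a
# sanitize hook load their weights unchanged.
model_class = Model
if hasattr(model_class, "sanitize"):
    weights = model_class.sanitize(weights)

assert "model.layers.0.self_attn.rotary_emb.inv_freq" not in weights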