mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
parent
4fa659acbd
commit
37b41cec60
@ -14,7 +14,7 @@ Some more useful examples are listed below.
|
|||||||
[Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms)
|
[Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms)
|
||||||
directory.
|
directory.
|
||||||
- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral).
|
- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral).
|
||||||
- Parameter efficient fine-tuning with [LoRA](lora).
|
- Parameter efficient fine-tuning with [LoRA or QLoRA](lora).
|
||||||
- Text-to-text multi-task Transformers with [T5](t5).
|
- Text-to-text multi-task Transformers with [T5](t5).
|
||||||
- Bidirectional language understanding with [BERT](bert).
|
- Bidirectional language understanding with [BERT](bert).
|
||||||
|
|
||||||
|
@ -32,6 +32,10 @@ page](https://huggingface.co/deepseek-ai) to see a list of available models.
|
|||||||
By default, the conversion script will save the converted `weights.npz`,
|
By default, the conversion script will save the converted `weights.npz`,
|
||||||
tokenizer, and `config.json` in the `mlx_model` directory.
|
tokenizer, and `config.json` in the `mlx_model` directory.
|
||||||
|
|
||||||
|
> [!TIP] Alternatively, you can also download a few converted checkpoints from
|
||||||
|
> the [MLX Community](https://huggingface.co/mlx-community) organization on
|
||||||
|
> Hugging Face and skip the conversion step.
|
||||||
|
|
||||||
### Run
|
### Run
|
||||||
|
|
||||||
Once you've converted the weights, you can interact with the Deepseek coder
|
Once you've converted the weights, you can interact with the Deepseek coder
|
||||||
|
@ -14,11 +14,13 @@ import torch
|
|||||||
from llama import Llama, ModelArgs, sanitize_config
|
from llama import Llama, ModelArgs, sanitize_config
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||||
|
|
||||||
|
|
||||||
def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
|
def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
|
||||||
# bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
|
# bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
|
||||||
a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
|
a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype))
|
||||||
return mx.array(a.numpy(), getattr(mx, dtype))
|
return mx.array(a.numpy(), getattr(mx, dtype))
|
||||||
|
|
||||||
|
|
||||||
def llama(model_path, *, dtype: str):
|
def llama(model_path, *, dtype: str):
|
||||||
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
||||||
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
||||||
@ -48,7 +50,7 @@ def llama(model_path, *, dtype: str):
|
|||||||
state = torch.load(wf, map_location=torch.device("cpu"))
|
state = torch.load(wf, map_location=torch.device("cpu"))
|
||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
v = torch_to_mx(v, dtype=dtype)
|
v = torch_to_mx(v, dtype=dtype)
|
||||||
state[k] = None # free memory
|
state[k] = None # free memory
|
||||||
if shard_key(k) in SHARD_WEIGHTS:
|
if shard_key(k) in SHARD_WEIGHTS:
|
||||||
weights[k].append(v)
|
weights[k].append(v)
|
||||||
else:
|
else:
|
||||||
@ -204,7 +206,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dtype",
|
"--dtype",
|
||||||
help="dtype for loading the torch model and input for quantization or saving the converted model. "
|
help="dtype for loading the torch model and input for quantization or saving the converted model. "
|
||||||
"The original weights are stored in bfloat16.",
|
"The original weights are stored in bfloat16.",
|
||||||
type=str,
|
type=str,
|
||||||
default="float16",
|
default="float16",
|
||||||
)
|
)
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# LoRA
|
# Fine-Tuning with LoRA or QLoRA
|
||||||
|
|
||||||
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
|
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
|
||||||
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
|
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
|
||||||
task.
|
task. The example also supports quantized LoRA (QLoRA).[^qlora]
|
||||||
|
|
||||||
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
|
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
|
||||||
generate SQL queries from natural language. However, the example is intended to
|
generate SQL queries from natural language. However, the example is intended to
|
||||||
@ -43,10 +43,13 @@ Convert the model with:
|
|||||||
|
|
||||||
```
|
```
|
||||||
python convert.py \
|
python convert.py \
|
||||||
--torch-model <path_to_torch_model> \
|
--torch-path <path_to_torch_model> \
|
||||||
--mlx-model <path_to_mlx_model>
|
--mlx-path <path_to_mlx_model>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you wish to use QLoRA, then convert the model with 4-bit quantization using
|
||||||
|
the `-q` option.
|
||||||
|
|
||||||
## Run
|
## Run
|
||||||
|
|
||||||
The main script is `lora.py`. To see a full list of options run
|
The main script is `lora.py`. To see a full list of options run
|
||||||
@ -65,8 +68,11 @@ python lora.py --model <path_to_model> \
|
|||||||
--iters 600
|
--iters 600
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If `--model` points to a quantized model, then the training will use QLoRA,
|
||||||
|
otherwise it will use regular LoRA.
|
||||||
|
|
||||||
Note, the model path should have the MLX weights, the tokenizer, and the
|
Note, the model path should have the MLX weights, the tokenizer, and the
|
||||||
`params.json` configuration which will all be output by the `convert.py` script.
|
`config.json` which will all be output by the `convert.py` script.
|
||||||
|
|
||||||
By default, the adapter weights are saved in `adapters.npz`. You can specify
|
By default, the adapter weights are saved in `adapters.npz`. You can specify
|
||||||
the output location with `--adapter-file`.
|
the output location with `--adapter-file`.
|
||||||
@ -137,16 +143,20 @@ Note other keys will be ignored by the loader.
|
|||||||
Fine-tuning a large model with LoRA requires a machine with a decent amount
|
Fine-tuning a large model with LoRA requires a machine with a decent amount
|
||||||
of memory. Here are some tips to reduce memory use should you need to do so:
|
of memory. Here are some tips to reduce memory use should you need to do so:
|
||||||
|
|
||||||
1. Try using a smaller batch size with `--batch-size`. The default is `4` so
|
1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
|
||||||
|
with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
|
||||||
|
more details.
|
||||||
|
|
||||||
|
2. Try using a smaller batch size with `--batch-size`. The default is `4` so
|
||||||
setting this to `2` or `1` will reduce memory consumption. This may slow
|
setting this to `2` or `1` will reduce memory consumption. This may slow
|
||||||
things down a little, but will also reduce the memory use.
|
things down a little, but will also reduce the memory use.
|
||||||
|
|
||||||
2. Reduce the number of layers to fine-tune with `--lora-layers`. The default
|
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
|
||||||
is `16`, so you can try `8` or `4`. This reduces the amount of memory
|
is `16`, so you can try `8` or `4`. This reduces the amount of memory
|
||||||
needed for back propagation. It may also reduce the quality of the
|
needed for back propagation. It may also reduce the quality of the
|
||||||
fine-tuned model if you are fine-tuning with a lot of data.
|
fine-tuned model if you are fine-tuning with a lot of data.
|
||||||
|
|
||||||
3. Longer examples require more memory. If it makes sense for your data, one thing
|
4. Longer examples require more memory. If it makes sense for your data, one thing
|
||||||
you can do is break your examples into smaller
|
you can do is break your examples into smaller
|
||||||
sequences when making the `{train, valid, test}.jsonl` files.
|
sequences when making the `{train, valid, test}.jsonl` files.
|
||||||
|
|
||||||
@ -164,6 +174,7 @@ The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
|
|||||||
|
|
||||||
|
|
||||||
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
||||||
|
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
|
||||||
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
|
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
|
||||||
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
|
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
|
||||||
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
|
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
|
||||||
|
104
lora/convert.py
104
lora/convert.py
@ -1,69 +1,125 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||||
|
|
||||||
|
from lora import Model, ModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
def quantize(weights, config, args):
|
||||||
|
quantized_config = copy.deepcopy(config)
|
||||||
|
|
||||||
|
# Load the model:
|
||||||
|
model = Model(ModelArgs(**config))
|
||||||
|
weights = tree_map(mx.array, weights)
|
||||||
|
model.update(tree_unflatten(list(weights.items())))
|
||||||
|
|
||||||
|
# Quantize the model:
|
||||||
|
nn.QuantizedLinear.quantize_module(
|
||||||
|
model,
|
||||||
|
args.q_group_size,
|
||||||
|
args.q_bits,
|
||||||
|
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
|
||||||
|
and m.weight.shape[0] != config["vocab_size"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the config:
|
||||||
|
quantized_config["quantization"] = {
|
||||||
|
"group_size": args.q_group_size,
|
||||||
|
"bits": args.q_bits,
|
||||||
|
}
|
||||||
|
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||||
|
|
||||||
|
return quantized_weights, quantized_config
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Convert Mistral or Llama models to MLX.",
|
description="Convert Mistral or Llama models to MLX.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torch-model",
|
"--torch-path",
|
||||||
type=str,
|
type=str,
|
||||||
default="mistral-7B-v0.1/",
|
default="mistral-7B-v0.1/",
|
||||||
help="The torch model directory",
|
help="Path to the torch model directory",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mlx-model",
|
"--mlx-path",
|
||||||
type=str,
|
type=str,
|
||||||
default="mlx-mistral-7B-v0.1/",
|
default="mlx_model/",
|
||||||
help="The directory to store the mlx model",
|
help="The directory to store the mlx model",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-q",
|
||||||
|
"--quantize",
|
||||||
|
help="Generate a quantized model.",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--q-group-size",
|
||||||
|
help="Group size for quantization.",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--q-bits",
|
||||||
|
help="Bits per weight for quantization.",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
torch_path = Path(args.torch_model)
|
args = parser.parse_args()
|
||||||
if not os.path.exists(args.mlx_model):
|
|
||||||
os.makedirs(args.mlx_model)
|
torch_path = Path(args.torch_path)
|
||||||
mlx_path = Path(args.mlx_model)
|
mlx_path = Path(args.mlx_path)
|
||||||
|
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Copy the tokenizer
|
# Copy the tokenizer
|
||||||
tokenizer_path = torch_path / "tokenizer.model"
|
tokenizer_path = torch_path / "tokenizer.model"
|
||||||
if not tokenizer_path.exists():
|
if not tokenizer_path.exists():
|
||||||
print(f"Make sure there is a file tokenizer.model in {args.torch_model}")
|
print(f"Make sure there is a file tokenizer.model in {args.torch-path}")
|
||||||
exit(0)
|
exit(0)
|
||||||
shutil.copyfile(
|
shutil.copyfile(
|
||||||
str(tokenizer_path),
|
str(tokenizer_path),
|
||||||
str(mlx_path / "tokenizer.model"),
|
str(mlx_path / "tokenizer.model"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copy the model weights
|
# Load the torch model weights to numpy:
|
||||||
state = torch.load(str(torch_path / "consolidated.00.pth"))
|
weights = torch.load(str(torch_path / "consolidated.00.pth"))
|
||||||
np.savez(
|
for k, v in weights.items():
|
||||||
str(mlx_path / "weights.npz"),
|
weights[k] = v.to(torch.float16).numpy()
|
||||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Copy the params
|
# Standardize the params
|
||||||
with open(torch_path / "params.json", "r") as f:
|
with open(torch_path / "params.json", "r") as f:
|
||||||
config = json.loads(f.read())
|
config = json.loads(f.read())
|
||||||
unused = ["multiple_of"]
|
unused = ["multiple_of", "sliding_window"]
|
||||||
for k in unused:
|
for k in unused:
|
||||||
if k in config:
|
config.pop(k, None)
|
||||||
config.pop(k)
|
|
||||||
n_heads = config["n_heads"]
|
n_heads = config["n_heads"]
|
||||||
if "sliding_window" in config:
|
|
||||||
config.pop("sliding_window")
|
|
||||||
if "n_kv_heads" not in config:
|
if "n_kv_heads" not in config:
|
||||||
config["n_kv_heads"] = n_heads
|
config["n_kv_heads"] = n_heads
|
||||||
if "head_dim" not in config:
|
if "head_dim" not in config:
|
||||||
config["head_dim"] = config["dim"] // n_heads
|
config["head_dim"] = config["dim"] // n_heads
|
||||||
if "hidden_dim" not in config:
|
if "hidden_dim" not in config:
|
||||||
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
|
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
|
||||||
with open(mlx_path / "params.json", "w") as outfile:
|
if config.get("vocab_size", -1) < 0:
|
||||||
|
config["vocab_size"] = weights["output.weight"].shape[0]
|
||||||
|
|
||||||
|
if args.quantize:
|
||||||
|
print("[INFO] Quantizing")
|
||||||
|
weights, config = quantize(weights, config, args)
|
||||||
|
|
||||||
|
np.savez(str(mlx_path / "weights.npz"), **weights)
|
||||||
|
|
||||||
|
with open(mlx_path / "config.json", "w") as outfile:
|
||||||
json.dump(config, outfile, indent=4)
|
json.dump(config, outfile, indent=4)
|
||||||
|
31
lora/lora.py
31
lora/lora.py
@ -17,12 +17,10 @@ from sentencepiece import SentencePieceProcessor
|
|||||||
|
|
||||||
|
|
||||||
def build_parser():
|
def build_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||||
description="LoRA finetuning with Llama or Mistral"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
required=True,
|
default="mlx_model",
|
||||||
help="A path to the model files containing the tokenizer, weights, config.",
|
help="A path to the model files containing the tokenizer, weights, config.",
|
||||||
)
|
)
|
||||||
# Generation args
|
# Generation args
|
||||||
@ -332,18 +330,22 @@ def generate(model, prompt, tokenizer, args):
|
|||||||
print(s, flush=True)
|
print(s, flush=True)
|
||||||
|
|
||||||
|
|
||||||
def load_model(folder: str, dtype=mx.float16):
|
def load_model(folder: str):
|
||||||
model_path = Path(folder)
|
model_path = Path(folder)
|
||||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||||
with open(model_path / "params.json", "r") as f:
|
with open(model_path / "config.json", "r") as f:
|
||||||
config = json.loads(f.read())
|
config = json.loads(f.read())
|
||||||
if config.get("vocab_size", -1) < 0:
|
quantization = config.pop("quantization", None)
|
||||||
config["vocab_size"] = tokenizer.vocab_size
|
|
||||||
model_args = ModelArgs(**config)
|
model_args = ModelArgs(**config)
|
||||||
|
model = Model(model_args)
|
||||||
|
if quantization is not None:
|
||||||
|
quantization["linear_class_predicate"] = lambda m: isinstance(
|
||||||
|
m, nn.Linear
|
||||||
|
) and (m.weight.shape[0] != model_args.vocab_size)
|
||||||
|
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||||
|
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
weights = mx.load(str(model_path / "weights.npz"))
|
||||||
weights = tree_unflatten(list(weights.items()))
|
weights = tree_unflatten(list(weights.items()))
|
||||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
|
||||||
model = Model(model_args)
|
|
||||||
model.update(weights)
|
model.update(weights)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@ -374,7 +376,7 @@ if __name__ == "__main__":
|
|||||||
# Resume training the given adapters.
|
# Resume training the given adapters.
|
||||||
if args.resume_adapter_file is not None:
|
if args.resume_adapter_file is not None:
|
||||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||||
model.load_weights(args.resume_adapter_file)
|
model.load_weights(args.resume_adapter_file, strict=False)
|
||||||
|
|
||||||
if args.train:
|
if args.train:
|
||||||
print("Training")
|
print("Training")
|
||||||
@ -387,7 +389,12 @@ if __name__ == "__main__":
|
|||||||
mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))
|
mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))
|
||||||
|
|
||||||
# Load the LoRA adapter weights which we assume should exist by this point
|
# Load the LoRA adapter weights which we assume should exist by this point
|
||||||
model.load_weights(args.adapter_file)
|
if not Path(args.adapter_file).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
f"Adapter file {args.adapter_file} missing. "
|
||||||
|
"Use --train to learn and save the adapters.npz."
|
||||||
|
)
|
||||||
|
model.load_weights(args.adapter_file, strict=False)
|
||||||
|
|
||||||
if args.test:
|
if args.test:
|
||||||
print("Testing")
|
print("Testing")
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
@ -24,7 +23,11 @@ class ModelArgs:
|
|||||||
class LoRALinear(nn.Module):
|
class LoRALinear(nn.Module):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_linear(linear: nn.Linear, rank: int = 8):
|
def from_linear(linear: nn.Linear, rank: int = 8):
|
||||||
|
# TODO remove when input_dims and output_dims are attributes
|
||||||
|
# on linear and quantized linear
|
||||||
output_dims, input_dims = linear.weight.shape
|
output_dims, input_dims = linear.weight.shape
|
||||||
|
if isinstance(linear, nn.QuantizedLinear):
|
||||||
|
input_dims *= 32 // linear.bits
|
||||||
lora_lin = LoRALinear(input_dims, output_dims, rank)
|
lora_lin = LoRALinear(input_dims, output_dims, rank)
|
||||||
lora_lin.linear = linear
|
lora_lin.linear = linear
|
||||||
return lora_lin
|
return lora_lin
|
||||||
@ -47,7 +50,10 @@ class LoRALinear(nn.Module):
|
|||||||
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
|
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
y = self.linear(x.astype(self.linear.weight.dtype))
|
dtype = self.linear.weight.dtype
|
||||||
|
if isinstance(self.linear, nn.QuantizedLinear):
|
||||||
|
dtype = self.linear.scales.dtype
|
||||||
|
y = self.linear(x.astype(dtype))
|
||||||
z = (x @ self.lora_a) @ self.lora_b
|
z = (x @ self.lora_a) @ self.lora_b
|
||||||
return y + 2.0 * z
|
return y + 2.0 * z
|
||||||
|
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
mlx
|
mlx>=0.0.7
|
||||||
sentencepiece
|
sentencepiece
|
||||||
torch
|
torch
|
||||||
|
Loading…
Reference in New Issue
Block a user