update
@@ -88,6 +88,32 @@ class Mamba2LMHeadModel(nn.Module):
        )

        # Tie the output head to the input embedding weights
        self.lm_head.weight = self.backbone.embedding.weight

    @staticmethod
    def from_pretrained(huggingface_model_id: str, device: Device = None):
        from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
        from transformers.utils.hub import cached_file

        # Resolve the config and weights files from the HuggingFace Hub cache
        config_path = cached_file(huggingface_model_id, CONFIG_NAME)
        assert config_path, "Failed to get huggingface config file"
        state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
        assert state_dict_path, "Failed to get huggingface state dict file"

        config = json.load(open(config_path))
        args = Mamba2Config(
            d_model=config["d_model"],
            n_layer=config["n_layer"],
            vocab_size=config["vocab_size"],
            pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
        )

        # Memory-map the checkpoint; keep it on CPU unless a device is given
        map_location = "cpu" if device is None else device
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=map_location, mmap=True
        )
        model = Mamba2LMHeadModel(args, device=device)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def forward(
        self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
@@ -193,7 +219,6 @@ class Mamba2(nn.Module):
        # Per-head SSM scalars: step-size bias, log decay rate, and skip weight
        self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
        self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
        self.D = nn.Parameter(torch.empty(args.nheads, device=device))

        self.norm = RMSNorm(args.d_inner, device=device)
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
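
As a hedged sketch of how Mamba2 conventionally consumes these three parameters (assumed from the standard Mamba2 formulation, not shown in this diff): dt_bias shifts the raw step sizes before a softplus, A_log parameterizes negative decay rates, and D scales a residual skip from input to output. All shapes below are illustrative:

    import torch
    import torch.nn.functional as F

    nheads = 8
    dt_bias = torch.zeros(nheads)
    A_log = torch.zeros(nheads)
    D = torch.ones(nheads)

    dt_raw = torch.randn(2, 16, nheads)       # (batch, seqlen, nheads) raw step sizes
    dt = F.softplus(dt_raw + dt_bias)         # positive discretization steps
    A = -torch.exp(A_log)                     # negative per-head decay rates
    x = torch.randn(2, 16, nheads, 64)        # (batch, seqlen, nheads, headdim)
    y = torch.zeros_like(x)                   # stand-in for the SSM scan output
    y = y + D[None, None, :, None] * x        # D acts as a per-head skip connection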