mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-02 06:41:13 +08:00
loading codestral works but not inference
This commit is contained in:
parent
a6ddc27a4e
commit
38e5801edb
@ -30,6 +30,7 @@ class ModelArgs(BaseModelArgs):
|
||||
rms_norm: bool
|
||||
chunk_size: int
|
||||
tie_word_embeddings: bool
|
||||
intermediate_size: int = None
|
||||
use_cache: bool = True
|
||||
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
|
||||
time_step_rank: Union[int, str] = "auto"
|
||||
@ -93,20 +94,21 @@ class DepthWiseConv1d(nn.Module):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.kernel_size = kernel_size
|
||||
self.groups = channels
|
||||
self.padding = padding
|
||||
self.weight = mx.random.normal((self.channels, kernel_size, 1))
|
||||
self.bias = mx.zeros((channels,)) if bias else None
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
B, L, C = x.shape
|
||||
groups, K, _ = self.weight.shape
|
||||
_, K, _ = self.weight.shape
|
||||
|
||||
if cache is not None:
|
||||
x = mx.concatenate([cache, x], axis=1)
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
||||
|
||||
y = mx.conv_general(x, self.weight, groups=groups)
|
||||
y = mx.conv_general(x, self.weight, groups=self.groups)
|
||||
|
||||
if self.bias is not None:
|
||||
y = y + self.bias
|
||||
@ -124,16 +126,20 @@ class Mamba2Block(nn.Module):
|
||||
self.d_state = args.state_size
|
||||
self.d_conv = args.conv_kernel
|
||||
self.expand = args.expand
|
||||
self.d_inner = int(self.expand * self.d_model)
|
||||
if args.intermediate_size == None:
|
||||
self.d_inner = int(self.expand * self.d_model)
|
||||
else:
|
||||
self.d_inner = args.intermediate_size
|
||||
self.n_groups = args.n_groups
|
||||
self.n_heads = args.num_heads
|
||||
self.d_head = self.d_inner // self.n_heads
|
||||
|
||||
# Input projection
|
||||
d_in_proj = self.d_inner * 2 + self.d_state * 2 + self.n_heads
|
||||
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
|
||||
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
|
||||
|
||||
# Convolution
|
||||
conv_dim = self.d_inner + 2 * self.d_state
|
||||
conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
|
||||
self.conv1d = DepthWiseConv1d(
|
||||
channels=conv_dim,
|
||||
kernel_size=self.d_conv,
|
||||
|
Loading…
Reference in New Issue
Block a user