Remove unused imports

This commit is contained in:
Shunta Saito 2025-02-28 04:05:27 +09:00
parent 8924bdc546
commit 71d7e99199

View File

@ -2,7 +2,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Union from typing import Any, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
mamba_enabled: bool = True mamba_enabled: bool = True
intermediate_size: int = 13312 intermediate_size: int = 13312
vocab_size: int = 32000 vocab_size: int = 32000
max_position_embeddings: int = 10 * 1024 * 1024
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
@ -53,7 +52,7 @@ class RMSNorm(nn.Module):
) )
def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array: def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
input_dtype = hidden_states.dtype input_dtype = hidden_states.dtype
hidden_states = hidden_states.astype(mx.float32) hidden_states = hidden_states.astype(mx.float32)
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True) variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
@ -230,8 +229,7 @@ def ssd_chunk_scan_combined(
def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]: def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
batch, seqlen, dim = x.shape _, seqlen, dim = x.shape
width = weight.shape[1]
state_len = conv_state.shape[-2] state_len = conv_state.shape[-2]
x = mx.concatenate([conv_state, x], axis=-2) x = mx.concatenate([conv_state, x], axis=-2)
conv_state = x[:, -state_len:] conv_state = x[:, -state_len:]