Remove unused imports

This commit is contained in:
Shunta Saito 2025-02-28 04:05:27 +09:00
parent 8924bdc546
commit 71d7e99199

View File

@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
mamba_enabled: bool = True
intermediate_size: int = 13312
vocab_size: int = 32000
max_position_embeddings: int = 10 * 1024 * 1024
class RMSNorm(nn.Module):
@ -53,7 +52,7 @@ class RMSNorm(nn.Module):
)
def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array:
def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.astype(mx.float32)
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
@ -230,8 +229,7 @@ def ssd_chunk_scan_combined(
def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
batch, seqlen, dim = x.shape
width = weight.shape[1]
_, seqlen, dim = x.shape
state_len = conv_state.shape[-2]
x = mx.concatenate([conv_state, x], axis=-2)
conv_state = x[:, -state_len:]