mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 03:55:20 +08:00
Remove unused imports
This commit is contained in:
parent
8924bdc546
commit
71d7e99199
@ -2,7 +2,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
|
||||
mamba_enabled: bool = True
|
||||
intermediate_size: int = 13312
|
||||
vocab_size: int = 32000
|
||||
max_position_embeddings: int = 10 * 1024 * 1024
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
@ -53,7 +52,7 @@ class RMSNorm(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array:
|
||||
def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.astype(mx.float32)
|
||||
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
|
||||
@ -230,8 +229,7 @@ def ssd_chunk_scan_combined(
|
||||
|
||||
|
||||
def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
|
||||
batch, seqlen, dim = x.shape
|
||||
width = weight.shape[1]
|
||||
_, seqlen, dim = x.shape
|
||||
state_len = conv_state.shape[-2]
|
||||
x = mx.concatenate([conv_state, x], axis=-2)
|
||||
conv_state = x[:, -state_len:]
|
||||
|
Loading…
Reference in New Issue
Block a user