mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Remove unused imports
This commit is contained in:
parent
8924bdc546
commit
71d7e99199
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
mamba_enabled: bool = True
|
mamba_enabled: bool = True
|
||||||
intermediate_size: int = 13312
|
intermediate_size: int = 13312
|
||||||
vocab_size: int = 32000
|
vocab_size: int = 32000
|
||||||
max_position_embeddings: int = 10 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
@ -53,7 +52,7 @@ class RMSNorm(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array:
|
def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
|
||||||
input_dtype = hidden_states.dtype
|
input_dtype = hidden_states.dtype
|
||||||
hidden_states = hidden_states.astype(mx.float32)
|
hidden_states = hidden_states.astype(mx.float32)
|
||||||
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
|
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
|
||||||
@ -230,8 +229,7 @@ def ssd_chunk_scan_combined(
|
|||||||
|
|
||||||
|
|
||||||
def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
|
def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
|
||||||
batch, seqlen, dim = x.shape
|
_, seqlen, dim = x.shape
|
||||||
width = weight.shape[1]
|
|
||||||
state_len = conv_state.shape[-2]
|
state_len = conv_state.shape[-2]
|
||||||
x = mx.concatenate([conv_state, x], axis=-2)
|
x = mx.concatenate([conv_state, x], axis=-2)
|
||||||
conv_state = x[:, -state_len:]
|
conv_state = x[:, -state_len:]
|
||||||
|
Loading…
Reference in New Issue
Block a user