From 71d7e991990fd60e7957d4a2558584d21d57969f Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Fri, 28 Feb 2025 04:05:27 +0900 Subject: [PATCH] Remove unused imports --- llms/mlx_lm/models/plamo2.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index def84221..657fa02e 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -2,7 +2,7 @@ import math from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn @@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs): mamba_enabled: bool = True intermediate_size: int = 13312 vocab_size: int = 32000 - max_position_embeddings: int = 10 * 1024 * 1024 class RMSNorm(nn.Module): @@ -53,7 +52,7 @@ class RMSNorm(nn.Module): ) -def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array: +def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array: input_dtype = hidden_states.dtype hidden_states = hidden_states.astype(mx.float32) variance = mx.power(hidden_states, 2).mean(-1, keepdims=True) @@ -230,8 +229,7 @@ def ssd_chunk_scan_combined( def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]: - batch, seqlen, dim = x.shape - width = weight.shape[1] + _, seqlen, dim = x.shape state_len = conv_state.shape[-2] x = mx.concatenate([conv_state, x], axis=-2) conv_state = x[:, -state_len:]