Apply formatter

This commit is contained in:
Shunta Saito 2025-02-13 20:02:38 +09:00
parent 07cf4336b3
commit fb5e225523

View File

@ -1917,7 +1917,7 @@ class Model(PlamoPreTrainedModel):
self.lm_head: nn.Module = nn.Linear(
config.hidden_size, vocab_size, bias=False
)
self._past_key_values: Optional[tuple[tuple[mx.array]]] = None
# Initialize weights and apply final processing
@ -1940,19 +1940,14 @@ class Model(PlamoPreTrainedModel):
def get_decoder(self) -> PlamoModel:
    """Return the object stored in ``self.model``."""
    decoder = self.model
    return decoder
def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
    """Swap axes 2 and 1 of selected ``conv1d`` weights, in place.

    An entry is rewritten when its key contains ``"conv1d.weight"`` and the
    last dimension of its value is not 1; the replacement is
    ``value.moveaxis(2, 1)``. All other entries are left untouched.

    NOTE(review): presumably this converts checkpoint conv weights to the
    layout the runtime expects -- confirm against the loader.

    Args:
        weights: Mapping from parameter name to array. It is mutated.

    Returns:
        The same ``weights`` mapping that was passed in.
    """
    # Only values of existing keys are reassigned, so the dict never changes
    # size while it is being iterated.
    for name in weights:
        if "conv1d.weight" not in name:
            continue
        tensor = weights[name]
        if tensor.shape[-1] != 1:
            weights[name] = tensor.moveaxis(2, 1)
    return weights
# Declared to return a ``PlamoCache`` (the type ``__call__`` accepts as its
# ``cache`` argument).
# NOTE(review): the body below is a debugging stub -- it prints "make_cache"
# and returns the string "a", which does not match the declared return type.
# Replace it with real cache construction -- TODO confirm the intended API.
def make_cache(self) -> PlamoCache:
print("make_cache")
return "a"
def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
print(cache)
output = self.forward(
input_ids=inputs,
use_cache=self.config.use_cache,
@ -2117,4 +2112,4 @@ class Model(PlamoPreTrainedModel):
@property
def layers(self):
    """Return ``self.model.layers``."""
    # The duplicated (unreachable) second ``return`` statement was removed.
    return self.model.layers