mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 04:25:06 +08:00
Apply formatter
This commit is contained in:
parent
07cf4336b3
commit
fb5e225523
@ -1917,7 +1917,7 @@ class Model(PlamoPreTrainedModel):
|
||||
self.lm_head: nn.Module = nn.Linear(
|
||||
config.hidden_size, vocab_size, bias=False
|
||||
)
|
||||
|
||||
|
||||
self._past_key_values: Optional[tuple[tuple[mx.array]]] = None
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
@ -1940,19 +1940,14 @@ class Model(PlamoPreTrainedModel):
|
||||
|
||||
def get_decoder(self) -> PlamoModel:
|
||||
return self.model
|
||||
|
||||
|
||||
def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
|
||||
for k, v in weights.items():
|
||||
if "conv1d.weight" in k and v.shape[-1] != 1:
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
return weights
|
||||
|
||||
def make_cache(self) -> PlamoCache:
|
||||
print("make_cache")
|
||||
return "a"
|
||||
|
||||
def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
|
||||
print(cache)
|
||||
output = self.forward(
|
||||
input_ids=inputs,
|
||||
use_cache=self.config.use_cache,
|
||||
@ -2117,4 +2112,4 @@ class Model(PlamoPreTrainedModel):
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
return self.model.layers
|
||||
|
Loading…
Reference in New Issue
Block a user