mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Apply formatter
This commit is contained in:
parent
07cf4336b3
commit
fb5e225523
@ -1917,7 +1917,7 @@ class Model(PlamoPreTrainedModel):
|
|||||||
self.lm_head: nn.Module = nn.Linear(
|
self.lm_head: nn.Module = nn.Linear(
|
||||||
config.hidden_size, vocab_size, bias=False
|
config.hidden_size, vocab_size, bias=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self._past_key_values: Optional[tuple[tuple[mx.array]]] = None
|
self._past_key_values: Optional[tuple[tuple[mx.array]]] = None
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
@ -1940,19 +1940,14 @@ class Model(PlamoPreTrainedModel):
|
|||||||
|
|
||||||
def get_decoder(self) -> PlamoModel:
|
def get_decoder(self) -> PlamoModel:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
|
def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
if "conv1d.weight" in k and v.shape[-1] != 1:
|
if "conv1d.weight" in k and v.shape[-1] != 1:
|
||||||
weights[k] = v.moveaxis(2, 1)
|
weights[k] = v.moveaxis(2, 1)
|
||||||
return weights
|
return weights
|
||||||
|
|
||||||
def make_cache(self) -> PlamoCache:
|
|
||||||
print("make_cache")
|
|
||||||
return "a"
|
|
||||||
|
|
||||||
def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
|
def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
|
||||||
print(cache)
|
|
||||||
output = self.forward(
|
output = self.forward(
|
||||||
input_ids=inputs,
|
input_ids=inputs,
|
||||||
use_cache=self.config.use_cache,
|
use_cache=self.config.use_cache,
|
||||||
@ -2117,4 +2112,4 @@ class Model(PlamoPreTrainedModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
Loading…
Reference in New Issue
Block a user