mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 01:12:24 +08:00
Add state property to PlamoCache
This commit is contained in:
parent
21c0abaf23
commit
e2d9d619c4
@ -194,6 +194,10 @@ class PlamoAttentionCache(nn.Module):
|
|||||||
assert value.shape[2] == L
|
assert value.shape[2] == L
|
||||||
self.key = key
|
self.key = key
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> tuple[mx.array, mx.array]:
|
||||||
|
return self.key, self.value
|
||||||
|
|
||||||
|
|
||||||
class PlamoMambaCache(nn.Module):
|
class PlamoMambaCache(nn.Module):
|
||||||
@ -207,6 +211,9 @@ class PlamoMambaCache(nn.Module):
|
|||||||
self.conv_state = conv_state
|
self.conv_state = conv_state
|
||||||
self.ssm_state = ssm_state
|
self.ssm_state = ssm_state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> tuple[mx.array, mx.array]:
|
||||||
|
return self.conv_state, self.ssm_state
|
||||||
|
|
||||||
PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
|
PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user