mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Length masking for batch inputs (#1173)
* length masking * add mask to mlx_lm model interface * remove lengths * fix test: * comment + fix
This commit is contained in:
parent
db109184b7
commit
d4ef909d4a
@ -23,7 +23,12 @@ class BaseModelArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
|
def create_causal_mask(
|
||||||
|
N: int,
|
||||||
|
offset: int = 0,
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
lengths: Optional[mx.array] = None,
|
||||||
|
):
|
||||||
rinds = mx.arange(offset + N)
|
rinds = mx.arange(offset + N)
|
||||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||||
linds = linds[:, None]
|
linds = linds[:, None]
|
||||||
@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
|
|||||||
mask = linds < rinds
|
mask = linds < rinds
|
||||||
if window_size is not None:
|
if window_size is not None:
|
||||||
mask = mask | (linds > rinds + window_size)
|
mask = mask | (linds > rinds + window_size)
|
||||||
|
if lengths is not None:
|
||||||
|
lengths = lengths[:, None, None, None]
|
||||||
|
mask = mask | (rinds >= lengths)
|
||||||
return mask * -1e9
|
return mask * -1e9
|
||||||
|
|
||||||
|
|
||||||
|
@ -155,10 +155,12 @@ class CohereModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -180,9 +182,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
out = out * self.model.args.logit_scale
|
out = out * self.model.args.logit_scale
|
||||||
return out
|
return out
|
||||||
|
@ -6,7 +6,7 @@ from typing import Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .cache import KVCache, RotatingKVCache
|
from .cache import KVCache, RotatingKVCache
|
||||||
|
|
||||||
|
|
||||||
@ -151,16 +151,13 @@ class CohereModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
T = h.shape[1]
|
if mask is None:
|
||||||
if T > 1:
|
mask = create_attention_mask(h, cache)
|
||||||
offset = cache[0].offset if cache else 0
|
|
||||||
mask = create_causal_mask(T, offset).astype(h.dtype)
|
|
||||||
else:
|
|
||||||
mask = None
|
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -181,9 +178,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
out = out * self.model.args.logit_scale
|
out = out * self.model.args.logit_scale
|
||||||
return out
|
return out
|
||||||
|
@ -197,10 +197,12 @@ class DBRX(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.wte(inputs)
|
h = self.wte(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -223,9 +225,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.transformer(inputs, cache)
|
out = self.transformer(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -211,8 +211,10 @@ class DeepseekModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(x)
|
h = self.embed_tokens(x)
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -236,8 +238,9 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache, mask)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -370,8 +370,11 @@ class DeepseekV2Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(x)
|
h = self.embed_tokens(x)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -395,8 +398,9 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache, mask)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -123,9 +123,11 @@ class ExaoneModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.wte(inputs)
|
h = self.wte(inputs)
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -149,9 +151,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.transformer(inputs, cache)
|
out = self.transformer(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.transformer.wte.as_linear(out)
|
out = self.transformer.wte.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -138,11 +138,13 @@ class GemmaModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
h = h * (self.args.hidden_size**0.5)
|
h = h * (self.args.hidden_size**0.5)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -164,9 +166,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -160,11 +160,13 @@ class GemmaModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
h = h * (self.args.hidden_size**0.5)
|
h = h * (self.args.hidden_size**0.5)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -187,9 +189,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
out = mx.tanh(out / self.final_logit_softcapping)
|
out = mx.tanh(out / self.final_logit_softcapping)
|
||||||
out = out * self.final_logit_softcapping
|
out = out * self.final_logit_softcapping
|
||||||
|
@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
_, L = inputs.shape
|
_, L = inputs.shape
|
||||||
@ -138,6 +139,7 @@ class GPT2Model(nn.Module):
|
|||||||
position_ids = mx.array(np.arange(L))
|
position_ids = mx.array(np.arange(L))
|
||||||
hidden_states += self.wpe(position_ids)
|
hidden_states += self.wpe(position_ids)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(hidden_states, cache)
|
mask = create_attention_mask(hidden_states, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -159,9 +161,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.wte.as_linear(out)
|
out = self.model.wte.as_linear(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
B, L = inputs.shape
|
B, L = inputs.shape
|
||||||
@ -149,6 +150,7 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
position_ids = mx.array(np.arange(L))
|
position_ids = mx.array(np.arange(L))
|
||||||
hidden_states += self.wpe(position_ids)
|
hidden_states += self.wpe(position_ids)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(hidden_states, cache)
|
mask = create_attention_mask(hidden_states, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -172,9 +174,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.transformer(inputs, cache)
|
out = self.transformer(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.transformer.wte.as_linear(out)
|
out = self.transformer.wte.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -146,12 +146,14 @@ class GPTNeoXModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
_, L = inputs.shape
|
_, L = inputs.shape
|
||||||
|
|
||||||
hidden_states = self.embed_in(inputs)
|
hidden_states = self.embed_in(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(hidden_states, cache)
|
mask = create_attention_mask(hidden_states, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -176,9 +178,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -239,10 +239,12 @@ class HunYuanModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -266,9 +268,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.model.embed_tokens.as_linear(out)
|
return self.model.embed_tokens.as_linear(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -193,10 +193,12 @@ class InternLM2Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.tok_embeddings(inputs)
|
h = self.tok_embeddings(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -220,9 +222,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.tok_embeddings.as_linear(out)
|
out = self.model.tok_embeddings.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -155,10 +155,12 @@ class LlamaModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -182,9 +184,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -158,10 +158,12 @@ class MiniCPMModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs) * self.args.scale_emb
|
h = self.embed_tokens(inputs) * self.args.scale_emb
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -186,9 +188,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
|
|
||||||
if not self.args.tie_word_embeddings:
|
if not self.args.tie_word_embeddings:
|
||||||
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
|
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
|
||||||
|
@ -162,10 +162,12 @@ class MixtralModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -188,9 +190,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -176,10 +176,12 @@ class NemotronModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -203,9 +205,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -124,10 +124,12 @@ class Transformer(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.wte(inputs)
|
h = self.wte(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -152,9 +154,10 @@ class OlmoModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
return self.transformer(inputs, cache)
|
return self.transformer(inputs, mask, cache)
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
@ -167,9 +170,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
return self.model(inputs, cache)
|
return self.model(inputs, mask, cache)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
|
@ -163,9 +163,11 @@ class LlamaModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache=None,
|
cache=None,
|
||||||
|
mask=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -190,8 +192,9 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache=None,
|
cache=None,
|
||||||
|
mask=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache, mask)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -178,10 +178,12 @@ class OpenELMModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.token_embeddings(inputs)
|
h = self.token_embeddings(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -205,9 +207,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.transformer(inputs, cache)
|
out = self.transformer(inputs, mask, cache)
|
||||||
if self.args.share_input_output_layers:
|
if self.args.share_input_output_layers:
|
||||||
out = self.transformer.token_embeddings.as_linear(out)
|
out = self.transformer.token_embeddings.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -143,9 +143,10 @@ class PhiModel(nn.Module):
|
|||||||
config.hidden_size, eps=config.layer_norm_eps
|
config.hidden_size, eps=config.layer_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x, cache):
|
def __call__(self, x, mask, cache):
|
||||||
x = self.embed_tokens(x)
|
x = self.embed_tokens(x)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(x, cache)
|
mask = create_attention_mask(x, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -167,9 +168,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
y = self.model(x, cache)
|
y = self.model(x, mask, cache)
|
||||||
return self.lm_head(y)
|
return self.lm_head(y)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -168,10 +168,12 @@ class Phi3Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -194,9 +196,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -258,12 +258,14 @@ class Phi3Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
if self.mup_embedding_multiplier:
|
if self.mup_embedding_multiplier:
|
||||||
h = self.mup_embedding_multiplier * h
|
h = self.mup_embedding_multiplier * h
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -290,9 +292,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
if self.mup_width_multiplier:
|
if self.mup_width_multiplier:
|
||||||
out = out / self.mup_width_multiplier
|
out = out / self.mup_width_multiplier
|
||||||
|
@ -155,10 +155,12 @@ class PhiMoEModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -181,9 +183,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -175,6 +175,8 @@ class Model(nn.Module):
|
|||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(x, cache)
|
mask = create_attention_mask(x, cache)
|
||||||
|
|
||||||
y = self.transformer(x, mask, cache)
|
y = self.transformer(x, mask, cache)
|
||||||
|
@ -174,9 +174,11 @@ class PlamoModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -202,8 +204,9 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache, mask)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -123,6 +123,7 @@ class QwenModel(nn.Module):
|
|||||||
def __call__(self, inputs, mask=None, cache=None):
|
def __call__(self, inputs, mask=None, cache=None):
|
||||||
x = self.wte(inputs)
|
x = self.wte(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(x, cache)
|
mask = create_attention_mask(x, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
|
@ -149,10 +149,12 @@ class Qwen2Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -176,9 +178,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -187,10 +187,12 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -213,9 +215,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -389,6 +389,7 @@ class Griffin(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
tokens,
|
tokens,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
x = self.embed_tokens(tokens)
|
x = self.embed_tokens(tokens)
|
||||||
@ -402,6 +403,7 @@ class Griffin(nn.Module):
|
|||||||
if block.temporal_block_type != "recurrent":
|
if block.temporal_block_type != "recurrent":
|
||||||
mask_cache = [cache[i]]
|
mask_cache = [cache[i]]
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(x, mask_cache)
|
mask = create_attention_mask(x, mask_cache)
|
||||||
|
|
||||||
for i, block in enumerate(self.layers):
|
for i, block in enumerate(self.layers):
|
||||||
@ -418,12 +420,12 @@ class Model(nn.Module):
|
|||||||
self.model_type = config.model_type
|
self.model_type = config.model_type
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
|
def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
tokens: Sequence of input tokens.
|
tokens: Sequence of input tokens.
|
||||||
"""
|
"""
|
||||||
logits = self.model(tokens, cache=cache)
|
logits = self.model(tokens, mask=mask, cache=cache)
|
||||||
if "lm_head" in self:
|
if "lm_head" in self:
|
||||||
logits = self.lm_head(logits)
|
logits = self.lm_head(logits)
|
||||||
else:
|
else:
|
||||||
|
@ -199,7 +199,10 @@ class Model(nn.Module):
|
|||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(x, cache)
|
mask = create_attention_mask(x, cache)
|
||||||
|
|
||||||
y = self.model(x, mask, cache)
|
y = self.model(x, mask, cache)
|
||||||
return self.lm_head(y)
|
return self.lm_head(y)
|
||||||
|
|
||||||
|
@ -125,10 +125,12 @@ class Starcoder2Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -152,9 +154,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -5,6 +5,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
from mlx_lm.models import rope_utils
|
from mlx_lm.models import rope_utils
|
||||||
|
from mlx_lm.models.base import create_causal_mask
|
||||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||||
|
|
||||||
|
|
||||||
@ -128,6 +129,22 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertEqual(cache.offset, 22)
|
self.assertEqual(cache.offset, 22)
|
||||||
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||||
|
|
||||||
|
def test_causal_mask_lengths(self):
|
||||||
|
mx.random.seed(8)
|
||||||
|
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
|
||||||
|
lengths = mx.array([1, 2, 3, 1])
|
||||||
|
q = mx.random.uniform(shape=(B, N_q, T_q, D))
|
||||||
|
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
|
||||||
|
v = k
|
||||||
|
mask = create_causal_mask(T_q, 0, lengths=lengths)
|
||||||
|
|
||||||
|
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||||
|
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
|
||||||
|
k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
|
||||||
|
v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
|
||||||
|
out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||||
|
self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
|
||||||
|
|
||||||
def test_rope(self):
|
def test_rope(self):
|
||||||
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
|
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
|
||||||
self.assertTrue(isinstance(rope, nn.RoPE))
|
self.assertTrue(isinstance(rope, nn.RoPE))
|
||||||
@ -162,7 +179,13 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertEqual(outputs.dtype, t)
|
self.assertEqual(outputs.dtype, t)
|
||||||
|
|
||||||
cache = make_prompt_cache(model)
|
cache = make_prompt_cache(model)
|
||||||
outputs = model(inputs, cache)
|
outputs = model(inputs, cache=cache)
|
||||||
|
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||||
|
self.assertEqual(outputs.dtype, t)
|
||||||
|
|
||||||
|
if model_type != "mamba":
|
||||||
|
mask = create_causal_mask(inputs.shape[1], 0).astype(t)
|
||||||
|
outputs = model(inputs, mask=mask)
|
||||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||||
self.assertEqual(outputs.dtype, t)
|
self.assertEqual(outputs.dtype, t)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user