mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
chore(clip): update the clip example to make it compatible with HF format (#472)
* chore(clip): update the clip model to be HF format * Update clip/convert.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * chore: address comments * chore: rename ClipVisionModel and ClipTextModel * chore: add output hidden_states support * chore: remove custom conv2d and apply weight transpose during weight sanitizing * Update clip/model.py * Update clip/model.py --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
f24edfa9dc
commit
47dd6bd17f
107
clip/convert.py
107
clip/convert.py
@ -1,15 +1,68 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
|
||||
max_file_size_bytes = max_file_size_gb << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += v.nbytes
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
||||
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
|
||||
"""Save model weights into specified directory."""
|
||||
if isinstance(save_path, str):
|
||||
save_path = Path(save_path)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shards = make_shards(weights)
|
||||
shards_count = len(shards)
|
||||
shard_file_format = (
|
||||
"model-{:05d}-of-{:05d}.safetensors"
|
||||
if shards_count > 1
|
||||
else "model.safetensors"
|
||||
)
|
||||
|
||||
total_size = sum(v.nbytes for v in weights.values())
|
||||
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
||||
|
||||
for i, shard in enumerate(shards):
|
||||
shard_name = shard_file_format.format(i + 1, shards_count)
|
||||
shard_path = save_path / shard_name
|
||||
|
||||
mx.save_safetensors(str(shard_path), shard)
|
||||
|
||||
for weight_name in shard.keys():
|
||||
index_data["weight_map"][weight_name] = shard_name
|
||||
|
||||
index_data["weight_map"] = {
|
||||
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
|
||||
}
|
||||
|
||||
with open(save_path / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(
|
||||
index_data,
|
||||
f,
|
||||
indent=4,
|
||||
)
|
||||
|
||||
|
||||
def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
@ -32,44 +85,6 @@ def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
|
||||
return mx.array(a.numpy(), getattr(mx, dtype))
|
||||
|
||||
|
||||
def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
|
||||
key = key.replace("embeddings.", "")
|
||||
key = key.replace("encoder.", "")
|
||||
key = key.replace("position_embedding.weight", "position_embedding")
|
||||
|
||||
# Map attention layers
|
||||
if "self_attn." in key:
|
||||
key = key.replace("self_attn.", "attention.")
|
||||
if "q_proj." in key:
|
||||
key = key.replace("q_proj.", "query_proj.")
|
||||
if "k_proj." in key:
|
||||
key = key.replace("k_proj.", "key_proj.")
|
||||
if "v_proj." in key:
|
||||
key = key.replace("v_proj.", "value_proj.")
|
||||
if "layer_norm1." in key:
|
||||
key = key.replace("layer_norm1.", "ln1.")
|
||||
if "layer_norm2." in key:
|
||||
key = key.replace("layer_norm2.", "ln2.")
|
||||
# Map ffn layers
|
||||
if "mlp.fc1" in key:
|
||||
key = key.replace("mlp.fc1", "linear1")
|
||||
if "mlp.fc2" in key:
|
||||
key = key.replace("mlp.fc2", "linear2")
|
||||
# Fix layernorm typo
|
||||
if "pre_layrnorm" in key:
|
||||
# Fix typo in weights :)
|
||||
key = key.replace("pre_layrnorm", "pre_layernorm")
|
||||
if "patch_embedding.weight" in key:
|
||||
# Initially, value: [out_channels, in_channels, kH, KW].
|
||||
# We want [out_channels, kH, KW, in_channels]
|
||||
value = value.permute(0, 2, 3, 1)
|
||||
return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))
|
||||
|
||||
|
||||
def should_keep_weight(key: str):
|
||||
return not ("position_ids" in key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download and Convert (OpenAI) CLIP weights to MLX"
|
||||
@ -86,7 +101,12 @@ if __name__ == "__main__":
|
||||
default="mlx_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="The data type to save the converted model.",
|
||||
type=str,
|
||||
default="float32",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch_path = get_model_path(args.hf_repo)
|
||||
@ -96,10 +116,11 @@ if __name__ == "__main__":
|
||||
print("[INFO] Loading")
|
||||
torch_weights = torch.load(torch_path / "pytorch_model.bin")
|
||||
print("[INFO] Converting")
|
||||
mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
|
||||
mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
|
||||
mlx_weights = {
|
||||
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
|
||||
}
|
||||
print("[INFO] Saving")
|
||||
mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
|
||||
save_weights(mlx_path, mlx_weights)
|
||||
for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
|
||||
shutil.copyfile(
|
||||
str(torch_path / f"{fn}"),
|
||||
|
250
clip/model.py
250
clip/model.py
@ -1,9 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -16,6 +19,7 @@ from mlx.utils import tree_flatten
|
||||
class CLIPVisionOutput:
|
||||
pooler_output: mx.array
|
||||
last_hidden_state: mx.array
|
||||
hidden_states: Optional[mx.array]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -41,6 +45,7 @@ class CLIPTextConfig:
|
||||
num_attention_heads: int
|
||||
max_position_embeddings: int
|
||||
vocab_size: int
|
||||
layer_norm_eps: float
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -52,6 +57,7 @@ class CLIPVisionConfig:
|
||||
num_channels: int
|
||||
image_size: int
|
||||
patch_size: int
|
||||
layer_norm_eps: float
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -75,50 +81,133 @@ def clip_loss(logits: mx.array) -> mx.array:
|
||||
return (caption_loss + image_loss) / 2.0
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.TransformerEncoderLayer):
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
num_heads: int,
|
||||
query_input_dims: Optional[int] = None,
|
||||
key_input_dims: Optional[int] = None,
|
||||
value_input_dims: Optional[int] = None,
|
||||
value_dims: Optional[int] = None,
|
||||
value_output_dims: Optional[int] = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if (dims % num_heads) != 0:
|
||||
raise ValueError(
|
||||
"The input feature dimensions should be divisible by the "
|
||||
f"number of heads ({dims} % {num_heads}) != 0"
|
||||
)
|
||||
|
||||
query_input_dims = query_input_dims or dims
|
||||
key_input_dims = key_input_dims or dims
|
||||
value_input_dims = value_input_dims or key_input_dims
|
||||
value_dims = value_dims or dims
|
||||
value_output_dims = value_output_dims or dims
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
|
||||
self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
|
||||
self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
|
||||
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None):
|
||||
queries = self.q_proj(queries)
|
||||
keys = self.k_proj(keys)
|
||||
values = self.v_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = quick_gelu
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""The transformer encoder layer from CLIP."""
|
||||
|
||||
def __init__(self, hidden_dim: int, intermediate_dim: int, num_heads: int):
|
||||
super().__init__(
|
||||
dims=hidden_dim,
|
||||
mlp_dims=intermediate_dim,
|
||||
num_heads=num_heads,
|
||||
activation=quick_gelu,
|
||||
norm_first=True,
|
||||
)
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
# Add biases to the attention projections
|
||||
self.attention = nn.MultiHeadAttention(hidden_dim, num_heads, bias=True)
|
||||
self.self_attn = Attention(
|
||||
config.hidden_size, config.num_attention_heads, bias=True
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = MLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
|
||||
y = self.layer_norm1(x)
|
||||
y = self.self_attn(y, y, y, mask)
|
||||
x = x + y
|
||||
y = self.layer_norm2(x)
|
||||
y = self.mlp(y)
|
||||
return x + y
|
||||
|
||||
|
||||
class CLIPTextModel(nn.Module):
|
||||
class TextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
config.max_position_embeddings, embed_dim
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
embeddings = self.token_embedding(x)
|
||||
embeddings += self.position_embedding.weight[: x.shape[1]]
|
||||
return embeddings
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
|
||||
|
||||
class ClipTextModel(nn.Module):
|
||||
"""Implements the text encoder transformer from CLIP."""
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.position_embedding = mx.zeros(
|
||||
(config.max_position_embeddings, config.hidden_size)
|
||||
)
|
||||
self.layers = [
|
||||
CLIPEncoderLayer(
|
||||
config.hidden_size, config.intermediate_size, config.num_attention_heads
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
self.embeddings = TextEmbeddings(config)
|
||||
self.encoder = Encoder(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def _embed(self, x: mx.array) -> mx.array:
|
||||
embeddings = self.token_embedding(x)
|
||||
embeddings += self.position_embedding[: x.shape[1]]
|
||||
return embeddings
|
||||
|
||||
def __call__(self, x: mx.array) -> CLIPTextOutput:
|
||||
B, N = x.shape
|
||||
eot_tokens = mx.argmax(x, axis=-1)
|
||||
x = self._embed(x)
|
||||
x = self.embeddings(x)
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
|
||||
for l in self.layers:
|
||||
for l in self.encoder.layers:
|
||||
x = l(x, mask)
|
||||
last_hidden_state = self.final_layer_norm(x)
|
||||
pooler_output = last_hidden_state[mx.arange(B), eot_tokens]
|
||||
@ -128,33 +217,29 @@ class CLIPTextModel(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionModel(nn.Module):
|
||||
"""Implements the vision encoder transformer from CLIP."""
|
||||
|
||||
class VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = mx.zeros((config.hidden_size,))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=config.patch_size,
|
||||
stride=config.patch_size,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
num_patches = (config.image_size // config.patch_size) ** 2
|
||||
num_positions = num_patches + 1
|
||||
self.position_embedding = mx.zeros((num_positions, config.hidden_size))
|
||||
self.pre_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
self.layers = [
|
||||
CLIPEncoderLayer(
|
||||
config.hidden_size, config.intermediate_size, config.num_attention_heads
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
self.post_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def _embed(self, x: mx.array) -> mx.array:
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
batch_size = x.shape[0]
|
||||
# Patchify using conv:
|
||||
# [batch_size, sqrt(num_patches), sqrt(num_patches), embed_dim]
|
||||
@ -170,25 +255,48 @@ class CLIPVisionModel(nn.Module):
|
||||
# [batch_size, num_patches + 1, embed_dim]
|
||||
embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
|
||||
# Add positional encoding
|
||||
embeddings += self.position_embedding
|
||||
embeddings += self.position_embedding.weight
|
||||
return embeddings
|
||||
|
||||
def __call__(self, x: mx.array) -> CLIPVisionOutput:
|
||||
x = self._embed(x)
|
||||
x = self.pre_layernorm(x)
|
||||
|
||||
for l in self.layers:
|
||||
class ClipVisionModel(nn.Module):
|
||||
"""Implements the vision encoder transformer from CLIP."""
|
||||
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__()
|
||||
self.embeddings = VisionEmbeddings(config)
|
||||
self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
|
||||
self.encoder = Encoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> CLIPVisionOutput:
|
||||
x = self.embeddings(x)
|
||||
x = self.pre_layrnorm(x)
|
||||
|
||||
encoder_states = (x,) if output_hidden_states else None
|
||||
|
||||
for l in self.encoder.layers:
|
||||
x = l(x, mask=None)
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (x,)
|
||||
|
||||
# Extract <CLS> token embedding
|
||||
pooler_output = self.post_layernorm(x[:, 0, :])
|
||||
return CLIPVisionOutput(pooler_output=pooler_output, last_hidden_state=x)
|
||||
return CLIPVisionOutput(
|
||||
pooler_output=pooler_output,
|
||||
last_hidden_state=x,
|
||||
hidden_states=encoder_states,
|
||||
)
|
||||
|
||||
|
||||
class CLIPModel(nn.Module):
|
||||
def __init__(self, config: CLIPConfig):
|
||||
self.text_model = CLIPTextModel(config.text_config)
|
||||
self.vision_model = CLIPVisionModel(config.vision_config)
|
||||
self.text_model = ClipTextModel(config.text_config)
|
||||
self.vision_model = ClipVisionModel(config.vision_config)
|
||||
|
||||
text_embed_dim = config.text_config.hidden_size
|
||||
vision_embed_dim = config.vision_config.hidden_size
|
||||
@ -259,6 +367,7 @@ class CLIPModel(nn.Module):
|
||||
num_attention_heads=text_config["num_attention_heads"],
|
||||
max_position_embeddings=text_config["max_position_embeddings"],
|
||||
vocab_size=text_config["vocab_size"],
|
||||
layer_norm_eps=text_config["layer_norm_eps"],
|
||||
)
|
||||
|
||||
vision_config = config["vision_config"]
|
||||
@ -271,6 +380,7 @@ class CLIPModel(nn.Module):
|
||||
num_channels=3,
|
||||
image_size=vision_config["image_size"],
|
||||
patch_size=vision_config["patch_size"],
|
||||
layer_norm_eps=vision_config["layer_norm_eps"],
|
||||
)
|
||||
|
||||
config = CLIPConfig(
|
||||
@ -279,5 +389,31 @@ class CLIPModel(nn.Module):
|
||||
projection_dim=config["projection_dim"],
|
||||
)
|
||||
model = CLIPModel(config)
|
||||
model.load_weights(str(path / "weights.npz"))
|
||||
weight_files = glob.glob(str(path / "*.safetensors"))
|
||||
if not weight_files:
|
||||
logging.error(f"No safetensors found in {path}")
|
||||
raise FileNotFoundError(f"No safetensors found in {path}")
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()))
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def sanitize(weights):
|
||||
sanitized_weights = {}
|
||||
for k, v in weights.items():
|
||||
if "position_ids" in k:
|
||||
# Remove unused position_ids
|
||||
continue
|
||||
elif "patch_embedding.weight" in k:
|
||||
# pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, KW]
|
||||
# mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels]
|
||||
sanitized_weights[k] = v.transpose(0, 2, 3, 1)
|
||||
else:
|
||||
sanitized_weights[k] = v
|
||||
|
||||
return sanitized_weights
|
||||
|
@ -3,4 +3,4 @@ numpy
|
||||
transformers
|
||||
torch
|
||||
huggingface_hub
|
||||
Pillow
|
||||
Pillow
|
||||
|
12
clip/test.py
12
clip/test.py
@ -86,12 +86,14 @@ class TestCLIP(unittest.TestCase):
|
||||
with torch.inference_mode():
|
||||
# Get expected
|
||||
x_tc = torch.tensor(x)
|
||||
expected_out = self.hf_clip.vision_model(x_tc)
|
||||
expected_out = self.hf_clip.vision_model(x_tc, output_hidden_states=True)
|
||||
expected_last_hidden = expected_out.last_hidden_state.numpy()
|
||||
expected_pooler_output = expected_out.pooler_output.numpy()
|
||||
|
||||
expected_hidden_states = [hs.numpy() for hs in expected_out.hidden_states]
|
||||
# Test MLX vision encoder
|
||||
out = self.mx_clip.vision_model(mx.array(x.transpose(0, 2, 3, 1)))
|
||||
out = self.mx_clip.vision_model(
|
||||
mx.array(x.transpose(0, 2, 3, 1)), output_hidden_states=True
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
|
||||
@ -102,6 +104,10 @@ class TestCLIP(unittest.TestCase):
|
||||
out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
|
||||
),
|
||||
)
|
||||
for expected_hs, out_hs in zip(expected_hidden_states, out.hidden_states):
|
||||
self.assertTrue(
|
||||
np.allclose(expected_hs, out_hs, rtol=1e-4, atol=1e-3),
|
||||
)
|
||||
|
||||
def test_clip_model(self):
|
||||
image_input = self.hf_image_proc(
|
||||
|
Loading…
Reference in New Issue
Block a user