chore(clip): update the clip example to make it compatible with HF format (#472)

* chore(clip): update the clip model to be compatible with the HF format

* Update clip/convert.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: address comments

* chore: rename ClipVisionModel and ClipTextModel

* chore: add output_hidden_states support (a sketch of the pattern follows the commit message below)

* chore: remove custom conv2d and apply weight transpose during weight sanitizing (a sketch of the sanitize pattern appears just before the diff)

* Update clip/model.py

* Update clip/model.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
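
For context on the output_hidden_states bullet above, here is a minimal sketch of the pattern, assuming the HF convention that hidden_states also includes the input embeddings (so it has num_layers + 1 entries). The ToyEncoder name and the nn.Linear stand-in layers are illustrative, not the actual clip/model.py code.

from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn


class ToyEncoder(nn.Module):
    """Illustrative encoder; the real CLIP encoder uses transformer blocks."""

    def __init__(self, num_layers: int, dims: int):
        super().__init__()
        self.layers = [nn.Linear(dims, dims) for _ in range(num_layers)]

    def __call__(self, x: mx.array, output_hidden_states: bool = False):
        # Start with the input embeddings, then append each layer's output,
        # mirroring the HF hidden_states convention.
        hidden_states: Optional[Tuple[mx.array, ...]] = (
            (x,) if output_hidden_states else None
        )
        for layer in self.layers:
            x = layer(x)
            if output_hidden_states:
                hidden_states += (x,)
        return x, hidden_states


# Usage: each entry lines up with one HF reference hidden state.
enc = ToyEncoder(num_layers=4, dims=8)
last, states = enc(mx.zeros((1, 16, 8)), output_hidden_states=True)
assert states is not None and len(states) == 5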
Authored by Anchen on 2024-02-24 01:49:53 +11:00; committed by GitHub
parent f24edfa9dc
commit 47dd6bd17f
4 changed files with 267 additions and 104 deletions
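
The "apply weight transpose during weight sanitizing" bullet refers to converting checkpoint conv weights to MLX's layout so the stock nn.Conv2d can consume them directly. A hedged sketch of that pattern follows; the key name patch_embedding.weight and the function name sanitize are assumptions, not necessarily the exact model.py code.

import mlx.core as mx


def sanitize(weights: dict) -> dict:
    sanitized = {}
    for k, v in weights.items():
        if "patch_embedding.weight" in k:
            # PyTorch/HF conv2d weights are (out_channels, in_channels, kH, kW);
            # MLX's nn.Conv2d expects (out_channels, kH, kW, in_channels).
            sanitized[k] = v.transpose(0, 2, 3, 1)
        else:
            sanitized[k] = v
    return sanitized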


@@ -86,12 +86,14 @@ class TestCLIP(unittest.TestCase):
         with torch.inference_mode():
             # Get expected
             x_tc = torch.tensor(x)
-            expected_out = self.hf_clip.vision_model(x_tc)
+            expected_out = self.hf_clip.vision_model(x_tc, output_hidden_states=True)
             expected_last_hidden = expected_out.last_hidden_state.numpy()
             expected_pooler_output = expected_out.pooler_output.numpy()
+            expected_hidden_states = [hs.numpy() for hs in expected_out.hidden_states]
 
         # Test MLX vision encoder
-        out = self.mx_clip.vision_model(mx.array(x.transpose(0, 2, 3, 1)))
+        out = self.mx_clip.vision_model(
+            mx.array(x.transpose(0, 2, 3, 1)), output_hidden_states=True
+        )
         self.assertTrue(
             np.allclose(
                 out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
@@ -102,6 +104,10 @@ class TestCLIP(unittest.TestCase):
                 out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
             ),
         )
+        for expected_hs, out_hs in zip(expected_hidden_states, out.hidden_states):
+            self.assertTrue(
+                np.allclose(expected_hs, out_hs, rtol=1e-4, atol=1e-3),
+            )
 
     def test_clip_model(self):
         image_input = self.hf_image_proc(