Mirror of https://github.com/ml-explore/mlx-examples.git
chore(clip): update the clip example to make it compatible with HF format (#472)
* chore(clip): update the clip model to be HF format
* Update clip/convert.py (Co-authored-by: Awni Hannun <awni.hannun@gmail.com>)
* chore: address comments
* chore: rename ClipVisionModel and ClipTextModel
* chore: add output hidden_states support
* chore: remove custom conv2d and apply weight transpose during weight sanitizing
* Update clip/model.py
* Update clip/model.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
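The "weight transpose during weight sanitizing" item refers to converting the PyTorch conv weights from the HF checkpoint into MLX's channels-last layout at load time, rather than keeping a custom conv2d implementation. Below is a minimal sketch of such a sanitize step, assuming MLX's `nn.Conv2d` expects weights shaped `(out_channels, kH, kW, in_channels)` while the HF checkpoint stores `(out_channels, in_channels, kH, kW)`; the key name checked here is an illustrative assumption, not necessarily the exact one used in `clip/model.py`:

```python
import mlx.core as mx


def sanitize(weights: dict) -> dict:
    # Sketch: transpose the vision patch-embedding conv weight from
    # PyTorch's (out, in, kH, kW) layout to MLX's (out, kH, kW, in).
    sanitized = {}
    for key, value in weights.items():
        # Assumption: the HF CLIP checkpoint stores the patch embedding
        # under a key containing "patch_embedding.weight".
        if "patch_embedding.weight" in key:
            value = value.transpose(0, 2, 3, 1)
        sanitized[key] = value
    return sanitized
```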
Changed file: clip/test.py (12 lines changed)
```diff
@@ -86,12 +86,14 @@ class TestCLIP(unittest.TestCase):
         with torch.inference_mode():
             # Get expected
             x_tc = torch.tensor(x)
-            expected_out = self.hf_clip.vision_model(x_tc)
+            expected_out = self.hf_clip.vision_model(x_tc, output_hidden_states=True)
             expected_last_hidden = expected_out.last_hidden_state.numpy()
             expected_pooler_output = expected_out.pooler_output.numpy()
+            expected_hidden_states = [hs.numpy() for hs in expected_out.hidden_states]

         # Test MLX vision encoder
-        out = self.mx_clip.vision_model(mx.array(x.transpose(0, 2, 3, 1)))
+        out = self.mx_clip.vision_model(
+            mx.array(x.transpose(0, 2, 3, 1)), output_hidden_states=True
+        )
         self.assertTrue(
             np.allclose(
                 out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
@@ -102,6 +104,10 @@ class TestCLIP(unittest.TestCase):
                 out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
             ),
         )
+        for expected_hs, out_hs in zip(expected_hidden_states, out.hidden_states):
+            self.assertTrue(
+                np.allclose(expected_hs, out_hs, rtol=1e-4, atol=1e-3),
+            )

     def test_clip_model(self):
         image_input = self.hf_image_proc(
```
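Two details of this test are worth spelling out. First, the PyTorch reference model receives `x` in NCHW layout, while the MLX model receives `x.transpose(0, 2, 3, 1)`, because MLX convolutions are channels-last (NHWC). Second, following the Hugging Face convention, `hidden_states` contains the embedding output followed by each encoder layer's output, and the new loop compares every entry against the reference. A small self-contained illustration of the layout conversion (the shapes are arbitrary example values):

```python
import numpy as np
import mlx.core as mx

# PyTorch convention: (batch, channels, height, width).
x_nchw = np.random.uniform(size=(1, 3, 224, 224)).astype(np.float32)

# MLX convolutions expect channels-last input, hence the transpose
# to (batch, height, width, channels) before calling the MLX model.
x_nhwc = mx.array(x_nchw.transpose(0, 2, 3, 1))
print(x_nhwc.shape)  # (1, 224, 224, 3)
```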