mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
revert accidental change
This commit is contained in:
parent
036090f508
commit
5a5decf767
@@ -219,25 +219,15 @@ def run(bert_model: str, mlx_model: str):
|
|||||||
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
||||||
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
||||||
|
|
||||||
vs = model_configs[bert_model].vocab_size
|
mlx_output, mlx_pooled = model(**tokens)
|
||||||
ts = np.random.randint(0, vs, (8, 512))
|
mlx_output = numpy.array(mlx_output)
|
||||||
tokens["input_ids"] = mx.array(ts)
|
mlx_pooled = numpy.array(mlx_pooled)
|
||||||
tokens["token_type_ids"] = mx.zeros((8, 512), mx.int32)
|
|
||||||
tokens.pop("attention_mask")
|
|
||||||
|
|
||||||
for _ in range(5):
|
print("MLX BERT:")
|
||||||
out = model(**tokens)
|
print(mlx_output)
|
||||||
mx.eval(out)
|
|
||||||
|
|
||||||
import time
|
print("\n\nMLX Pooled:")
|
||||||
|
print(mlx_pooled[0, :20])
|
||||||
tic = time.time()
|
|
||||||
for _ in range(10):
|
|
||||||
out = model(**tokens)
|
|
||||||
mx.eval(out)
|
|
||||||
toc = time.time()
|
|
||||||
tps = (8 * 5 * 10) / (toc - tic)
|
|
||||||
print(tps)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user