mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
revert accidental change
This commit is contained in:
parent
036090f508
commit
5a5decf767
@ -219,25 +219,15 @@ def run(bert_model: str, mlx_model: str):
|
||||
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
||||
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
||||
|
||||
vs = model_configs[bert_model].vocab_size
|
||||
ts = np.random.randint(0, vs, (8, 512))
|
||||
tokens["input_ids"] = mx.array(ts)
|
||||
tokens["token_type_ids"] = mx.zeros((8, 512), mx.int32)
|
||||
tokens.pop("attention_mask")
|
||||
mlx_output, mlx_pooled = model(**tokens)
|
||||
mlx_output = numpy.array(mlx_output)
|
||||
mlx_pooled = numpy.array(mlx_pooled)
|
||||
|
||||
for _ in range(5):
|
||||
out = model(**tokens)
|
||||
mx.eval(out)
|
||||
print("MLX BERT:")
|
||||
print(mlx_output)
|
||||
|
||||
import time
|
||||
|
||||
tic = time.time()
|
||||
for _ in range(10):
|
||||
out = model(**tokens)
|
||||
mx.eval(out)
|
||||
toc = time.time()
|
||||
tps = (8 * 5 * 10) / (toc - tic)
|
||||
print(tps)
|
||||
print("\n\nMLX Pooled:")
|
||||
print(mlx_pooled[0, :20])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user