diff --git a/bert/model.py b/bert/model.py
index 00344ab6..794254f6 100644
--- a/bert/model.py
+++ b/bert/model.py
@@ -219,25 +219,15 @@ def run(bert_model: str, mlx_model: str):
     tokens = tokenizer(batch, return_tensors="np", padding=True)
     tokens = {key: mx.array(v) for key, v in tokens.items()}
 
-    vs = model_configs[bert_model].vocab_size
-    ts = np.random.randint(0, vs, (8, 512))
-    tokens["input_ids"] = mx.array(ts)
-    tokens["token_type_ids"] = mx.zeros((8, 512), mx.int32)
-    tokens.pop("attention_mask")
+    mlx_output, mlx_pooled = model(**tokens)
+    mlx_output = numpy.array(mlx_output)
+    mlx_pooled = numpy.array(mlx_pooled)
 
-    for _ in range(5):
-        out = model(**tokens)
-        mx.eval(out)
+    print("MLX BERT:")
+    print(mlx_output)
 
-    import time
-
-    tic = time.time()
-    for _ in range(10):
-        out = model(**tokens)
-        mx.eval(out)
-    toc = time.time()
-    tps = (8 * 5 * 10) / (toc - tic)
-    print(tps)
+    print("\n\nMLX Pooled:")
+    print(mlx_pooled[0, :20])
 
 
 if __name__ == "__main__":