diff --git a/benchmarks/python/blas/bench_gemv.py b/benchmarks/python/blas/bench_gemv.py index 5f491ffc8..2b564a78a 100644 --- a/benchmarks/python/blas/bench_gemv.py +++ b/benchmarks/python/blas/bench_gemv.py @@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype): return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3) -def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose): +def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose): np_dtype = getattr(np, dtype) mlx_gb_s = [] mlx_gflops = [] @@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose): ax.legend() -def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose): +def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose): np_dtype = getattr(np, dtype) mlx_gb_s = [] mlx_gflops = [] diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 1787cbd95..014707b38 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -517,7 +517,7 @@ array slice( // Gather moves the axis up, remainder needs to be squeezed out_reshape[i] = indices[i].size(); - // Gather moves the axis up, needs to be tranposed + // Gather moves the axis up, needs to be transposed t_axes[ax] = i; } diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index a59052936..8a7d632fa 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -81,12 +81,12 @@ class TestBlas(mlx_tests.MLXTestCase): for B, M, N, K in shapes: - with self.subTest(tranpose="nn"): + with self.subTest(transpose="nn"): shape_a = (B, M, K) shape_b = (B, K, N) self.__gemm_test(shape_a, shape_b, np_dtype) - with self.subTest(tranpose="nt"): + with self.subTest(transpose="nt"): shape_a = (B, M, K) shape_b = (B, N, K) self.__gemm_test( @@ -97,7 +97,7 @@ class TestBlas(mlx_tests.MLXTestCase): f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)), ) - with self.subTest(tranpose="tn"): + with self.subTest(transpose="tn"): shape_a = (B, K, M) shape_b = (B, K, N) self.__gemm_test( @@ -108,7 +108,7 @@ class TestBlas(mlx_tests.MLXTestCase): f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)), ) - with self.subTest(tranpose="tt"): + with self.subTest(transpose="tt"): shape_a = (B, K, M) shape_b = (B, N, K) self.__gemm_test(