spelling: transpose

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
Josh Soref 2024-01-01 22:46:12 -05:00
parent 9557a8fa6f
commit 072f7e0b8c
3 changed files with 7 additions and 7 deletions

View File

@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3) return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose): def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype) np_dtype = getattr(np, dtype)
mlx_gb_s = [] mlx_gb_s = []
mlx_gflops = [] mlx_gflops = []
@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
ax.legend() ax.legend()
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose): def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype) np_dtype = getattr(np, dtype)
mlx_gb_s = [] mlx_gb_s = []
mlx_gflops = [] mlx_gflops = []

View File

@ -517,7 +517,7 @@ array slice(
// Gather moves the axis up, remainder needs to be squeezed // Gather moves the axis up, remainder needs to be squeezed
out_reshape[i] = indices[i].size(); out_reshape[i] = indices[i].size();
// Gather moves the axis up, needs to be tranposed // Gather moves the axis up, needs to be transposed
t_axes[ax] = i; t_axes[ax] = i;
} }

View File

@ -81,12 +81,12 @@ class TestBlas(mlx_tests.MLXTestCase):
for B, M, N, K in shapes: for B, M, N, K in shapes:
with self.subTest(tranpose="nn"): with self.subTest(transpose="nn"):
shape_a = (B, M, K) shape_a = (B, M, K)
shape_b = (B, K, N) shape_b = (B, K, N)
self.__gemm_test(shape_a, shape_b, np_dtype) self.__gemm_test(shape_a, shape_b, np_dtype)
with self.subTest(tranpose="nt"): with self.subTest(transpose="nt"):
shape_a = (B, M, K) shape_a = (B, M, K)
shape_b = (B, N, K) shape_b = (B, N, K)
self.__gemm_test( self.__gemm_test(
@ -97,7 +97,7 @@ class TestBlas(mlx_tests.MLXTestCase):
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)), f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
) )
with self.subTest(tranpose="tn"): with self.subTest(transpose="tn"):
shape_a = (B, K, M) shape_a = (B, K, M)
shape_b = (B, K, N) shape_b = (B, K, N)
self.__gemm_test( self.__gemm_test(
@ -108,7 +108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)), f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
) )
with self.subTest(tranpose="tt"): with self.subTest(transpose="tt"):
shape_a = (B, K, M) shape_a = (B, K, M)
shape_b = (B, N, K) shape_b = (B, N, K)
self.__gemm_test( self.__gemm_test(