From ffff67127341f466f1c74eaee71096b61ff6441b Mon Sep 17 00:00:00 2001
From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com>
Date: Thu, 11 Apr 2024 18:27:53 +0400
Subject: [PATCH] Update pre-commit hooks (#984)

---
 .pre-commit-config.yaml                   |  4 ++--
 benchmarks/cpp/time_utils.h               |  9 ++++-----
 .../steel/conv/loaders/loader_channel_l.h |  2 +-
 .../steel/conv/loaders/loader_channel_n.h |  2 +-
 .../steel/conv/loaders/loader_general.h   |  2 +-
 mlx/backend/metal/matmul.cpp              | 25 ++++++++++++-------------
 mlx/io/load.cpp                           |  3 +--
 mlx/ops.cpp                               | 26 ++++++++++++--------------
 mlx/transforms.cpp                        |  4 ++--
 tests/custom_vjp_tests.cpp                |  4 +---
 10 files changed, 37 insertions(+), 44 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dd5ebec30..ae9db3839 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,11 +1,11 @@
 repos:
 - repo: https://github.com/pre-commit/mirrors-clang-format
-  rev: v17.0.6
+  rev: v18.1.3
   hooks:
   - id: clang-format
 # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
 - repo: https://github.com/psf/black-pre-commit-mirror
-  rev: 24.2.0
+  rev: 24.3.0
   hooks:
   - id: black
 - repo: https://github.com/pycqa/isort
diff --git a/benchmarks/cpp/time_utils.h b/benchmarks/cpp/time_utils.h
index 09ba6c173..cf9d21a16 100644
--- a/benchmarks/cpp/time_utils.h
+++ b/benchmarks/cpp/time_utils.h
@@ -17,11 +17,10 @@
             << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
             << std::endl;
 
-#define TIMEM(MSG, FUNC, ...)                                                  \
-  std::cout << "Timing "                                                       \
-            << "(" << MSG << ") " << #FUNC << " ... " << std::flush            \
-            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
-            << std::endl;
+#define TIMEM(MSG, FUNC, ...)                                                \
+  std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... "           \
+            << std::flush << std::setprecision(5)                            \
+            << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
 
 template <typename F, typename... Args>
 double time_fn(F fn, Args&&... args) {
diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
index dad496e81..e355df3cd 100644
--- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
+++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
@@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
       const constant ImplicitGemmConv2DParams* gemm_params_,
       uint simd_group_id [[simdgroup_index_in_threadgroup]],
       uint simd_lane_id [[thread_index_in_simdgroup]])
-      : src_ld(params_->wt_strides[0]),
+      : src_ld(params_ -> wt_strides[0]),
         thread_idx(simd_group_id * 32 + simd_lane_id),
         bi(thread_idx / TCOLS),
         bj(vec_size * (thread_idx % TCOLS)),
diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
index 56027916e..1b947fced 100644
--- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
+++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
@@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
       const constant ImplicitGemmConv2DParams* gemm_params_,
       uint simd_group_id [[simdgroup_index_in_threadgroup]],
       uint simd_lane_id [[thread_index_in_simdgroup]])
-      : src_ld(params_->wt_strides[0]),
+      : src_ld(params_ -> wt_strides[0]),
         thread_idx(simd_group_id * 32 + simd_lane_id),
         bi(thread_idx / TCOLS),
         bj(vec_size * (thread_idx % TCOLS)),
diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
index 3e396c2af..b93fd927a 100644
--- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
+++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
@@ -220,7 +220,7 @@ struct Conv2DWeightBlockLoaderGeneral {
       const short base_ww_,
       uint simd_group_id [[simdgroup_index_in_threadgroup]],
       uint simd_lane_id [[thread_index_in_simdgroup]])
-      : src_ld(params_->wt_strides[0]),
+      : src_ld(params_ -> wt_strides[0]),
         thread_idx(simd_group_id * 32 + simd_lane_id),
         bi(thread_idx / TCOLS),
         bj(vec_size * (thread_idx % TCOLS)),
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 0ea89e51b..856f4d67e 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -197,8 +197,8 @@ inline auto collapse_batches(const array& a, const array& b) {
   std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
   if (A_bshape != B_bshape) {
     std::ostringstream msg;
-    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
-        << "A " << a.shape() << ", B " << b.shape() << ".";
+    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
+        << a.shape() << ", B " << b.shape() << ".";
     throw std::runtime_error(msg.str());
   }
 
@@ -227,9 +227,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
   std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
   if (A_bshape != B_bshape || A_bshape != C_bshape) {
     std::ostringstream msg;
-    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
-        << "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
-        << ".";
+    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
+        << a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
     throw std::runtime_error(msg.str());
   }
 
@@ -332,8 +331,8 @@ void steel_matmul(
           << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
           << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
           << "_wm" << wm << "_wn" << wn << "_MN_"
-          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
-          << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+          << ((K % bk == 0) ? "t" : "n") << "aligned";
 
     // Encode and dispatch gemm kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
@@ -422,8 +421,8 @@
           << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
           << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
           << "_wm" << wm << "_wn" << wn << "_MN_"
-          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
-          << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+          << ((K % bk == 0) ? "t" : "n") << "aligned";
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
@@ -903,8 +902,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
           << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
           << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
           << "_wm" << wm << "_wn" << wn << "_MN_"
-          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
-          << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+          << ((K % bk == 0) ? "t" : "n") << "aligned";
 
     // Encode and dispatch gemm kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
@@ -992,8 +991,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
           << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
           << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
           << "_wm" << wm << "_wn" << wn << "_MN_"
-          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
-          << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
+          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+          << ((K % bk == 0) ? "t" : "n") << "aligned"
           << ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
 
     // Encode and dispatch kernel
diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp
index b1e73ce37..294a1229f 100644
--- a/mlx/io/load.cpp
+++ b/mlx/io/load.cpp
@@ -63,8 +63,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
   std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
   std::ostringstream header;
   header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
-         << " 'fortran_order': " << fortran_order << ","
-         << " 'shape': (";
+         << " 'fortran_order': " << fortran_order << "," << " 'shape': (";
   for (auto i : a.shape()) {
     header << i << ", ";
   }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index ceb6b291d..80cd40127 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -932,15 +932,15 @@ array pad(
     if (low_pad_size[i] < 0) {
       std::ostringstream msg;
       msg << "Invalid low padding size (" << low_pad_size[i]
-          << ") passed to pad"
-          << " for axis " << i << ". Padding sizes must be non-negative";
+          << ") passed to pad" << " for axis " << i
+          << ". Padding sizes must be non-negative";
       throw std::invalid_argument(msg.str());
     }
     if (high_pad_size[i] < 0) {
       std::ostringstream msg;
       msg << "Invalid high padding size (" << high_pad_size[i]
-          << ") passed to pad"
-          << " for axis " << i << ". Padding sizes must be non-negative";
+          << ") passed to pad" << " for axis " << i
+          << ". Padding sizes must be non-negative";
       throw std::invalid_argument(msg.str());
     }
 
@@ -2508,8 +2508,8 @@ array take_along_axis(
     StreamOrDevice s /* = {} */) {
   if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
     std::ostringstream msg;
-    msg << "[take_along_axis] Received invalid axis "
-        << " for array with " << a.ndim() << " dimensions.";
+    msg << "[take_along_axis] Received invalid axis " << " for array with "
+        << a.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
 
@@ -2904,15 +2904,15 @@ inline std::vector<int> conv_out_shape(
 
     if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
       std::ostringstream msg;
-      msg << "[conv] Padding sizes must be non-negative."
-          << " Got padding " << pads_lo << " | " << pads_hi << ".";
+      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
+          << pads_lo << " | " << pads_hi << ".";
       throw std::invalid_argument(msg.str());
     }
 
     if (strides[i - 1] <= 0) {
       std::ostringstream msg;
-      msg << "[conv] Stride sizes must be positive."
-          << " Got strides " << strides << ".";
+      msg << "[conv] Stride sizes must be positive." << " Got strides "
+          << strides << ".";
       throw std::invalid_argument(msg.str());
     }
 
@@ -2948,8 +2948,7 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
   if (in.ndim() != n_dim + 2) {
     std::ostringstream msg;
     msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
-        << n_dim << "D convolution."
-        << " Expected an array with " << n_dim + 2
+        << n_dim << "D convolution." << " Expected an array with " << n_dim + 2
         << " dimensions following the format [N, ..., C_in].";
     throw std::invalid_argument(msg.str());
   }
@@ -3236,8 +3235,7 @@ std::tuple<array, array, array> quantize(
     std::ostringstream msg;
     msg << "[quantize] The last dimension of the matrix needs to be divisible by "
        << "the quantization group size " << group_size
-        << ". However the provided "
-        << " matrix has shape " << w.shape();
+        << ". However the provided " << " matrix has shape " << w.shape();
     throw std::invalid_argument(msg.str());
   }
 
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index a19bb8c30..f6ab6f747 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -24,8 +24,8 @@ class Synchronizer : public Primitive {
  public:
   explicit Synchronizer(Stream stream) : Primitive(stream){};
 
-  void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
-  void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
+  void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
+  void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
 
   DEFINE_PRINT(Synchronize);
 };
diff --git a/tests/custom_vjp_tests.cpp b/tests/custom_vjp_tests.cpp
index f916b694b..e0b4029be 100644
--- a/tests/custom_vjp_tests.cpp
+++ b/tests/custom_vjp_tests.cpp
@@ -18,9 +18,7 @@ TEST_CASE("test simple custom vjp") {
       fn,
       [&](const std::vector<array>&,
           const std::vector<array>&,
-          const std::vector<array>&) {
-        return std::vector<array>{one, one};
-      });
+          const std::vector<array>&) { return std::vector<array>{one, one}; });
 
   auto [z, g] = vjp(fn, {x, y}, {one, one});
   CHECK_EQ(z[0].item<float>(), 6.0f);