Update pre-commit hooks (#984)

Author: Nripesh Niketan
Date: 2024-04-11 18:27:53 +04:00 (committed by GitHub)
Commit: ffff671273 (parent: 12d4507ee3)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)

10 changed files with 37 additions and 44 deletions


@@ -1,11 +1,11 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v17.0.6
+    rev: v18.1.3
     hooks:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.2.0
+    rev: 24.3.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort


@@ -17,11 +17,10 @@
               << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
               << std::endl;
-#define TIMEM(MSG, FUNC, ...)                                                   \
-  std::cout << "Timing "                                                        \
-            << "(" << MSG << ") " << #FUNC << " ... " << std::flush             \
-            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec"  \
-            << std::endl;
+#define TIMEM(MSG, FUNC, ...)                                       \
+  std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... "  \
+            << std::flush << std::setprecision(5)                   \
+            << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
 template <typename F, typename... Args>
 double time_fn(F fn, Args&&... args) {
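Note: time_fn is only declared in the context lines above. For readers without the full file, here is a minimal sketch of such a helper, assuming a plain std::chrono wall-clock measurement returning milliseconds; this is an illustration, not the repository's actual implementation.

#include <chrono>
#include <utility>

// Hypothetical sketch: run fn(args...) once and return the elapsed
// wall-clock time in milliseconds, matching the " msec" label printed
// by the TIME/TIMEM macros above.
template <typename F, typename... Args>
double time_fn(F fn, Args&&... args) {
  auto start = std::chrono::high_resolution_clock::now();
  fn(std::forward<Args>(args)...);
  auto stop = std::chrono::high_resolution_clock::now();
  return std::chrono::duration<double, std::milli>(stop - start).count();
}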


@@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
       const constant ImplicitGemmConv2DParams* gemm_params_,
       uint simd_group_id [[simdgroup_index_in_threadgroup]],
       uint simd_lane_id [[thread_index_in_simdgroup]])
-      : src_ld(params_->wt_strides[0]),
+      : src_ld(params_ -> wt_strides[0]),
         thread_idx(simd_group_id * 32 + simd_lane_id),
         bi(thread_idx / TCOLS),
         bj(vec_size * (thread_idx % TCOLS)),


@@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
       const constant ImplicitGemmConv2DParams* gemm_params_,
       uint simd_group_id [[simdgroup_index_in_threadgroup]],
       uint simd_lane_id [[thread_index_in_simdgroup]])
-      : src_ld(params_->wt_strides[0]),
+      : src_ld(params_ -> wt_strides[0]),
         thread_idx(simd_group_id * 32 + simd_lane_id),
         bi(thread_idx / TCOLS),
         bj(vec_size * (thread_idx % TCOLS)),


@@ -220,7 +220,7 @@ struct Conv2DWeightBlockLoaderGeneral {
       const short base_ww_,
       uint simd_group_id [[simdgroup_index_in_threadgroup]],
       uint simd_lane_id [[thread_index_in_simdgroup]])
-      : src_ld(params_->wt_strides[0]),
+      : src_ld(params_ -> wt_strides[0]),
         thread_idx(simd_group_id * 32 + simd_lane_id),
         bi(thread_idx / TCOLS),
         bj(vec_size * (thread_idx % TCOLS)),


@@ -197,8 +197,8 @@ inline auto collapse_batches(const array& a, const array& b) {
   std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
   if (A_bshape != B_bshape) {
     std::ostringstream msg;
-    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
-        << "A " << a.shape() << ", B " << b.shape() << ".";
+    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
+        << a.shape() << ", B " << b.shape() << ".";
     throw std::runtime_error(msg.str());
   }
@@ -227,9 +227,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
   std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
   if (A_bshape != B_bshape || A_bshape != C_bshape) {
     std::ostringstream msg;
-    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
-        << "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
-        << ".";
+    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
+        << a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
     throw std::runtime_error(msg.str());
   }
@@ -332,8 +331,8 @@ void steel_matmul(
         << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
         << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn << "_MN_"
-        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
-        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+        << ((K % bk == 0) ? "t" : "n") << "aligned";
   // Encode and dispatch gemm kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -422,8 +421,8 @@
         << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
         << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn << "_MN_"
-        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
-        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+        << ((K % bk == 0) ? "t" : "n") << "aligned";
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -903,8 +902,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
         << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn << "_MN_"
-        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
-        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+        << ((K % bk == 0) ? "t" : "n") << "aligned";
   // Encode and dispatch gemm kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -992,8 +991,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
         << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn << "_MN_"
-        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
-        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
+        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+        << ((K % bk == 0) ? "t" : "n") << "aligned"
         << ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
   // Encode and dispatch kernel
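Every hunk in this file is the same clang-format 18 change: adjacent string literals streamed with << are now packed onto one line. As a standalone illustration of the kernel-name suffix these lines assemble (alignment_suffix is a hypothetical helper written for this note, not part of the commit):

#include <sstream>
#include <string>

// Illustrative only: encodes MN- and K-tile alignment as 't'/'n' flags,
// mirroring the kernel-name construction in the hunks above.
std::string alignment_suffix(int M, int N, int K, int bm, int bn, int bk) {
  std::ostringstream kname;
  kname << "_MN_" << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
        << ((K % bk == 0) ? "t" : "n") << "aligned";
  return kname.str();
}

For example, alignment_suffix(64, 64, 63, 32, 32, 16) yields "_MN_taligned_K_naligned".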


@@ -63,8 +63,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
   std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
   std::ostringstream header;
   header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
-         << " 'fortran_order': " << fortran_order << ","
-         << " 'shape': (";
+         << " 'fortran_order': " << fortran_order << "," << " 'shape': (";
   for (auto i : a.shape()) {
     header << i << ", ";
   }


@@ -932,15 +932,15 @@ array pad(
     if (low_pad_size[i] < 0) {
       std::ostringstream msg;
       msg << "Invalid low padding size (" << low_pad_size[i]
-          << ") passed to pad"
-          << " for axis " << i << ". Padding sizes must be non-negative";
+          << ") passed to pad" << " for axis " << i
+          << ". Padding sizes must be non-negative";
       throw std::invalid_argument(msg.str());
     }
     if (high_pad_size[i] < 0) {
       std::ostringstream msg;
       msg << "Invalid high padding size (" << high_pad_size[i]
-          << ") passed to pad"
-          << " for axis " << i << ". Padding sizes must be non-negative";
+          << ") passed to pad" << " for axis " << i
+          << ". Padding sizes must be non-negative";
       throw std::invalid_argument(msg.str());
     }
@@ -2508,8 +2508,8 @@ array take_along_axis(
     StreamOrDevice s /* = {} */) {
   if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
     std::ostringstream msg;
-    msg << "[take_along_axis] Received invalid axis "
-        << " for array with " << a.ndim() << " dimensions.";
+    msg << "[take_along_axis] Received invalid axis " << " for array with "
+        << a.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
@@ -2904,15 +2904,15 @@ inline std::vector<int> conv_out_shape(
     if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
       std::ostringstream msg;
-      msg << "[conv] Padding sizes must be non-negative."
-          << " Got padding " << pads_lo << " | " << pads_hi << ".";
+      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
+          << pads_lo << " | " << pads_hi << ".";
       throw std::invalid_argument(msg.str());
     }
     if (strides[i - 1] <= 0) {
       std::ostringstream msg;
-      msg << "[conv] Stride sizes must be positive."
-          << " Got strides " << strides << ".";
+      msg << "[conv] Stride sizes must be positive." << " Got strides "
+          << strides << ".";
       throw std::invalid_argument(msg.str());
     }
@@ -2948,8 +2948,7 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
   if (in.ndim() != n_dim + 2) {
     std::ostringstream msg;
     msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
-        << n_dim << "D convolution."
-        << " Expected an array with " << n_dim + 2
+        << n_dim << "D convolution." << " Expected an array with " << n_dim + 2
         << " dimensions following the format [N, ..., C_in].";
     throw std::invalid_argument(msg.str());
   }
@@ -3236,8 +3235,7 @@ std::tuple<array, array, array> quantize(
     std::ostringstream msg;
     msg << "[quantize] The last dimension of the matrix needs to be divisible by "
         << "the quantization group size " << group_size
-        << ". However the provided "
-        << " matrix has shape " << w.shape();
+        << ". However the provided " << " matrix has shape " << w.shape();
     throw std::invalid_argument(msg.str());
   }
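The ops.cpp hunks all touch the same validation idiom: build the message in a std::ostringstream, then throw. A self-contained sketch of the pattern as clang-format 18 now lays it out (check_stride is a made-up name for this note, not the repo's API):

#include <sstream>
#include <stdexcept>

// Illustrative only: the ostringstream-then-throw pattern reformatted
// throughout this file, with adjacent string literals kept on one line.
void check_stride(int stride) {
  if (stride <= 0) {
    std::ostringstream msg;
    msg << "[conv] Stride sizes must be positive." << " Got strides "
        << stride << ".";
    throw std::invalid_argument(msg.str());
  }
}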


@@ -24,8 +24,8 @@ class Synchronizer : public Primitive {
  public:
   explicit Synchronizer(Stream stream) : Primitive(stream){};
-  void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
-  void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
+  void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
+  void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
   DEFINE_PRINT(Synchronize);
 };
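The only change here is whitespace: clang-format 18 separates override from an empty function body. A minimal standalone illustration (Base and Noop are hypothetical types written for this note, not MLX's):

// clang-format 17 emitted "override{};"; clang-format 18 emits "override {};".
struct Base {
  virtual void run() = 0;
  virtual ~Base() = default;
};

struct Noop : public Base {
  void run() override {};
};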


@@ -18,9 +18,7 @@ TEST_CASE("test simple custom vjp") {
       fn,
       [&](const std::vector<array>&,
           const std::vector<array>&,
-          const std::vector<array>&) {
-        return std::vector<array>{one, one};
-      });
+          const std::vector<array>&) { return std::vector<array>{one, one}; });
   auto [z, g] = vjp(fn, {x, y}, {one, one});
   CHECK_EQ(z[0].item<float>(), 6.0f);