mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Update pre-commit hooks (#984)
This commit is contained in:
parent
12d4507ee3
commit
ffff671273
@ -1,11 +1,11 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v17.0.6
|
||||
rev: v18.1.3
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.2.0
|
||||
rev: 24.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
@ -17,11 +17,10 @@
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " \
|
||||
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
|
||||
<< std::flush << std::setprecision(5) \
|
||||
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
|
||||
|
||||
template <typename F, typename... Args>
|
||||
double time_fn(F fn, Args&&... args) {
|
||||
|
@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
|
||||
const constant ImplicitGemmConv2DParams* gemm_params_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
|
||||
const constant ImplicitGemmConv2DParams* gemm_params_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@ -220,7 +220,7 @@ struct Conv2DWeightBlockLoaderGeneral {
|
||||
const short base_ww_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@ -197,8 +197,8 @@ inline auto collapse_batches(const array& a, const array& b) {
|
||||
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
if (A_bshape != B_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
|
||||
<< "A " << a.shape() << ", B " << b.shape() << ".";
|
||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
@ -227,9 +227,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
|
||||
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
|
||||
if (A_bshape != B_bshape || A_bshape != C_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
|
||||
<< "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
|
||||
<< ".";
|
||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
@ -332,8 +331,8 @@ void steel_matmul(
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch gemm kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@ -422,8 +421,8 @@ void steel_matmul(
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@ -903,8 +902,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch gemm kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@ -992,8 +991,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
|
||||
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned"
|
||||
<< ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
|
||||
|
||||
// Encode and dispatch kernel
|
||||
|
@ -63,8 +63,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
|
||||
std::ostringstream header;
|
||||
header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
|
||||
<< " 'fortran_order': " << fortran_order << ","
|
||||
<< " 'shape': (";
|
||||
<< " 'fortran_order': " << fortran_order << "," << " 'shape': (";
|
||||
for (auto i : a.shape()) {
|
||||
header << i << ", ";
|
||||
}
|
||||
|
26
mlx/ops.cpp
26
mlx/ops.cpp
@ -932,15 +932,15 @@ array pad(
|
||||
if (low_pad_size[i] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid low padding size (" << low_pad_size[i]
|
||||
<< ") passed to pad"
|
||||
<< " for axis " << i << ". Padding sizes must be non-negative";
|
||||
<< ") passed to pad" << " for axis " << i
|
||||
<< ". Padding sizes must be non-negative";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (high_pad_size[i] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid high padding size (" << high_pad_size[i]
|
||||
<< ") passed to pad"
|
||||
<< " for axis " << i << ". Padding sizes must be non-negative";
|
||||
<< ") passed to pad" << " for axis " << i
|
||||
<< ". Padding sizes must be non-negative";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@ -2508,8 +2508,8 @@ array take_along_axis(
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[take_along_axis] Received invalid axis "
|
||||
<< " for array with " << a.ndim() << " dimensions.";
|
||||
msg << "[take_along_axis] Received invalid axis " << " for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@ -2904,15 +2904,15 @@ inline std::vector<int> conv_out_shape(
|
||||
|
||||
if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Padding sizes must be non-negative."
|
||||
<< " Got padding " << pads_lo << " | " << pads_hi << ".";
|
||||
msg << "[conv] Padding sizes must be non-negative." << " Got padding "
|
||||
<< pads_lo << " | " << pads_hi << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (strides[i - 1] <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Stride sizes must be positive."
|
||||
<< " Got strides " << strides << ".";
|
||||
msg << "[conv] Stride sizes must be positive." << " Got strides "
|
||||
<< strides << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@ -2948,8 +2948,7 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
||||
if (in.ndim() != n_dim + 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
|
||||
<< n_dim << "D convolution."
|
||||
<< " Expected an array with " << n_dim + 2
|
||||
<< n_dim << "D convolution." << " Expected an array with " << n_dim + 2
|
||||
<< " dimensions following the format [N, ..., C_in].";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@ -3236,8 +3235,7 @@ std::tuple<array, array, array> quantize(
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided "
|
||||
<< " matrix has shape " << w.shape();
|
||||
<< ". However the provided " << " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
|
@ -24,8 +24,8 @@ class Synchronizer : public Primitive {
|
||||
public:
|
||||
explicit Synchronizer(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
|
||||
|
||||
DEFINE_PRINT(Synchronize);
|
||||
};
|
||||
|
@ -18,9 +18,7 @@ TEST_CASE("test simple custom vjp") {
|
||||
fn,
|
||||
[&](const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&) {
|
||||
return std::vector<array>{one, one};
|
||||
});
|
||||
const std::vector<array>&) { return std::vector<array>{one, one}; });
|
||||
|
||||
auto [z, g] = vjp(fn, {x, y}, {one, one});
|
||||
CHECK_EQ(z[0].item<float>(), 6.0f);
|
||||
|
Loading…
Reference in New Issue
Block a user