From 45a8b226af38f9d334a8b887c929168856212d2c Mon Sep 17 00:00:00 2001
From: Ronan Collobert
Date: Thu, 30 Oct 2025 16:24:51 -0700
Subject: [PATCH] WIP (cpu)

---
 mlx/backend/cpu/arange.h        |  2 +-
 mlx/backend/cpu/binary.cpp      |  7 +++++-
 mlx/backend/cpu/cholesky.cpp    |  4 ++--
 mlx/backend/cpu/gemm.h          |  8 +++----
 mlx/backend/cpu/gemms/bnns.cpp  | 22 ++++++++---------
 mlx/backend/cpu/gemms/cblas.cpp | 42 ++++++++++++++++-----------------
 mlx/backend/cpu/hadamard.cpp    | 10 ++++----
 mlx/backend/cpu/indexing.cpp    | 30 +++++++++++------------
 mlx/backend/cpu/masked_mm.cpp   | 28 +++++++++++-----------
 mlx/backend/cpu/primitives.cpp  | 22 ++++++++---------
 mlx/backend/cpu/qrf.cpp         | 10 ++++----
 mlx/backend/cpu/simd/math.h     |  7 +++---
 mlx/backend/cpu/sort.cpp        | 20 ++++++++--------
 mlx/backend/cpu/svd.cpp         | 12 +++++-----
 mlx/backend/cpu/ternary.h       |  2 +-
 mlx/backend/cpu/unary.h         | 10 ++++----
 16 files changed, 121 insertions(+), 115 deletions(-)

diff --git a/mlx/backend/cpu/arange.h b/mlx/backend/cpu/arange.h
index 9e9b03bd7..96b1e6e04 100644
--- a/mlx/backend/cpu/arange.h
+++ b/mlx/backend/cpu/arange.h
@@ -10,7 +10,7 @@ namespace mlx::core {
 namespace {
 
 template <typename T>
-void arange(T start, T next, array& out, size_t size, Stream stream) {
+void arange(T start, T next, array& out, int64_t size, Stream stream) {
   auto ptr = out.data<T>();
   auto step_size = next - start;
   auto& encoder = cpu::get_command_encoder(stream);
diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp
index 94dac1435..d98b7332b 100644
--- a/mlx/backend/cpu/binary.cpp
+++ b/mlx/backend/cpu/binary.cpp
@@ -17,7 +17,12 @@ namespace mlx::core {
 namespace {
 
 template <typename T, typename U, typename Op>
-void binary(const array& a, const array& b, array& out, Op /* op */, Stream stream) {
+void binary(
+    const array& a,
+    const array& b,
+    array& out,
+    Op /* op */,
+    Stream stream) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
 
diff --git a/mlx/backend/cpu/cholesky.cpp b/mlx/backend/cpu/cholesky.cpp
index 3c5bbbc93..244642340 100644
--- a/mlx/backend/cpu/cholesky.cpp
+++ b/mlx/backend/cpu/cholesky.cpp
@@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
        N = a.shape(-1),
        size = a.size()]() mutable {
     char uplo = (upper) ? 'L' : 'U';
-    size_t num_matrices = size / (N * N);
-    for (int i = 0; i < num_matrices; i++) {
+    int64_t num_matrices = size / (N * N);
+    for (int64_t i = 0; i < num_matrices; i++) {
       // Compute Cholesky factorization.
       int info;
       potrf<T>(
diff --git a/mlx/backend/cpu/gemm.h b/mlx/backend/cpu/gemm.h
index d665cb91f..93aa3ee31 100644
--- a/mlx/backend/cpu/gemm.h
+++ b/mlx/backend/cpu/gemm.h
@@ -12,12 +12,12 @@ void matmul(
     T* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
diff --git a/mlx/backend/cpu/gemms/bnns.cpp b/mlx/backend/cpu/gemms/bnns.cpp
index 2ec0fd4e2..545d79d4b 100644
--- a/mlx/backend/cpu/gemms/bnns.cpp
+++ b/mlx/backend/cpu/gemms/bnns.cpp
@@ -34,7 +34,7 @@ void matmul_bnns(
     bool b_transposed,
     size_t lda,
     size_t ldb,
-    size_t ldc,
+    size_t /* ldc */,
     float alpha,
     float beta,
     size_t batch_size,
@@ -52,7 +52,7 @@ void matmul_bnns(
 #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
   if (beta != 1.0 && beta != 0.0) {
     // scale the output
-    for (auto i = 0; i < batch_size * M * N; ++i) {
+    for (size_t i = 0; i < batch_size * M * N; ++i) {
       out[i] *= beta;
     }
     beta = 1.0;
@@ -127,7 +127,7 @@ void matmul_bnns(
   auto bnns_filter =
       BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
 
-  for (int i = 0; i < batch_size; ++i) {
+  for (size_t i = 0; i < batch_size; ++i) {
     BNNSFilterApplyTwoInput(
         bnns_filter,
         reinterpret_cast<const uint8_t*>(
@@ -148,12 +148,12 @@ void matmul<float16_t>(
     float16_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
@@ -183,12 +183,12 @@ void matmul<bfloat16_t>(
     bfloat16_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
diff --git a/mlx/backend/cpu/gemms/cblas.cpp b/mlx/backend/cpu/gemms/cblas.cpp
index 765e9f539..3277b7a78 100644
--- a/mlx/backend/cpu/gemms/cblas.cpp
+++ b/mlx/backend/cpu/gemms/cblas.cpp
@@ -13,20 +13,20 @@ void matmul<float>(
     float* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
 
   for (int i = 0; i < batch_size; ++i) {
     cblas_sgemm(
@@ -54,20 +54,20 @@ void matmul<double>(
     double* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
 
   for (int i = 0; i < batch_size; ++i) {
     cblas_dgemm(
@@ -95,20 +95,20 @@ void matmul<complex64_t>(
     complex64_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
 
   auto calpha = static_cast<complex64_t>(alpha);
   auto cbeta = static_cast<complex64_t>(beta);
diff --git a/mlx/backend/cpu/hadamard.cpp b/mlx/backend/cpu/hadamard.cpp
index bf7e1dc26..aa1b164bb 100644
--- a/mlx/backend/cpu/hadamard.cpp
+++ b/mlx/backend/cpu/hadamard.cpp
@@ -11,9 +11,9 @@ namespace mlx::core {
 
 // n = 2^k component
 template <typename T>
-void hadamard_n(T* out, int n, int m, float scale, size_t size) {
+void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
   for (int b = 0; b < size / n; b++) {
-    size_t loc = b * n;
+    int64_t loc = b * n;
     T* data_ptr = out + loc;
     int h = 1;
     int n_over_2 = n / 2;
@@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) {
 
 // m component
 template <typename T>
-void hadamard_m(T* out, int n, int m, float scale, size_t size) {
+void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
   auto h_matrices = hadamard_matrices();
   auto& matrix = h_matrices[m];
   auto start = 1;
@@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
   std::vector<bool> hmat_vec;
   while (end != std::string_view::npos) {
     auto row = matrix.substr(start, end - start);
-    for (int i = 0; i < row.length(); i++) {
+    for (int i = 0; i < std::ssize(row); i++) {
       hmat_vec.push_back(row[i] == '+');
     }
     start = end + 1;
@@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
   }
 
   for (int b = 0; b < size / m / n; b++) {
-    size_t loc = b * n * m;
+    int64_t loc = b * n * m;
     T* data_ptr = out + loc;
     for (int i = 0; i < n; i++) {
       std::vector<float> out(m);
diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp
index 6daced6fa..b743550fd 100644
--- a/mlx/backend/cpu/indexing.cpp
+++ b/mlx/backend/cpu/indexing.cpp
@@ -78,7 +78,7 @@ void gather(
     can_copy = true;
 
     // Ignore leading 1s
-    int i = 0;
+    int64_t i = 0;
     for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
       ;
 
@@ -91,7 +91,7 @@ void gather(
     can_copy = true;
 
     // Ignore trailing 1s
-    int i = slice_sizes.size() - 1;
+    int64_t i = slice_sizes.size() - 1;
     for (; i >= 0 && slice_sizes[i] == 1; --i)
       ;
 
@@ -101,11 +101,11 @@ void gather(
       can_copy = (src.shape(i) == slice_sizes[i]);
     }
   }
-  size_t slice_size = 1;
+  int64_t slice_size = 1;
   for (auto s : slice_sizes) {
     slice_size *= s;
   }
-  size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
+  int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
 
   const T* src_ptr = src.data<T>();
   T* dst_ptr = out.data<T>();
@@ -115,10 +115,10 @@ void gather(
     src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
   }
 
-  size_t out_idx = 0;
-  for (int idx = 0; idx < ind_size; idx++) {
-    size_t src_idx = 0;
-    for (int ii = 0; ii < inds.size(); ++ii) {
+  int64_t out_idx = 0;
+  for (int64_t idx = 0; idx < ind_size; idx++) {
+    int64_t src_idx = 0;
+    for (int ii = 0; ii < std::ssize(inds); ++ii) {
       auto ax = axes[ii];
       auto idx_loc = its[ii].loc;
       its[ii].step();
@@ -134,7 +134,7 @@ void gather(
           src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
       out_idx += slice_size;
     } else {
-      for (int jj = 0; jj < slice_size; jj++) {
+      for (int64_t jj = 0; jj < slice_size; jj++) {
         dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
         src_it.step();
       }
@@ -403,11 +403,11 @@ void scatter(
     const std::vector<int>& axes) {
   int nind = inds.size();
   auto inds_ndim = updates.ndim() - out.ndim();
-  size_t n_updates = nind ? inds[0].size() : 1;
+  int64_t n_updates = nind ? inds[0].size() : 1;
 
   Shape update_shape(
       updates.shape().begin() + inds_ndim, updates.shape().end());
-  size_t update_size = 1;
+  int64_t update_size = 1;
   for (auto us : update_shape) {
     update_size *= us;
   }
@@ -418,9 +418,9 @@ void scatter(
 
   auto out_ptr = out.data<T>();
   auto upd_ptr = updates.data<T>();
-  for (int i = 0; i < n_updates; ++i) {
-    size_t out_offset = 0;
-    for (int j = 0; j < inds.size(); ++j) {
+  for (int64_t i = 0; i < n_updates; ++i) {
+    int64_t out_offset = 0;
+    for (int j = 0; j < std::ssize(inds); ++j) {
       auto ax = axes[j];
       auto idx_loc = its[j].loc;
       its[j].step();
@@ -429,7 +429,7 @@ void scatter(
       out_offset += (idx_val * out.strides()[ax]);
     }
     update_it.seek(i * update_size);
-    for (int j = 0; j < update_size; ++j) {
+    for (int64_t j = 0; j < update_size; ++j) {
       OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
       update_it.step();
       out_it.step();
diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp
index 688479c60..81012a84d 100644
--- a/mlx/backend/cpu/masked_mm.cpp
+++ b/mlx/backend/cpu/masked_mm.cpp
@@ -25,7 +25,7 @@ inline void mask_matrix(
     const int64_t Y_data_str,
     const int64_t X_mask_str,
     const int64_t Y_mask_str,
-    const size_t mask_offset) {
+    const int64_t mask_offset) {
   int tX = (X + block_size - 1) / block_size;
   int tY = (Y + block_size - 1) / block_size;
 
@@ -61,13 +61,13 @@ inline void segmented_mm(
     T* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
+    int64_t lda,
+    int64_t ldb,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides,
-    size_t num_segments,
+    int64_t num_segments,
     const Shape& segments_shape,
     const Strides& segments_strides) {
   int ndim = a_shape.size();
@@ -149,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto [b_transposed, ldb, b, b_copied] = check_transpose(
       b_pre, has_op_mask, inputs.back().dtype() != bool_);
 
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
+  int64_t M = a.shape(-2);
+  int64_t N = b.shape(-1);
+  int64_t K = a.shape(-1);
 
   if (M == 0 || N == 0) {
     return;
@@ -172,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       int batch_idx,
       int X,
       int Y,
-      size_t X_data_str,
-      size_t Y_data_str,
+      int64_t X_data_str,
+      int64_t Y_data_str,
       const Shape& mask_shape,
       const Strides& mask_strides,
       bool is_bool) {
@@ -253,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto a_ptr = a.data<float>();
   auto b_ptr = b.data<float>();
   auto out_ptr = out.data<float>();
-  size_t num_matrices = out.size() / (M * size_t(N));
+  int64_t num_matrices = out.size() / (M * int64_t(N));
   auto ldc = out.shape(-1);
 
   encoder.dispatch([a_ptr,
@@ -394,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
 
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
+  int64_t M = a.shape(-2);
+  int64_t N = b.shape(-1);
+  int64_t K = a.shape(-1);
 
   if (M == 0 || N == 0) {
     return;
@@ -413,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Get batch dims
   auto batch_size_out = out.size() / (M * N);
-  size_t matrix_stride_out = M * N;
+  int64_t matrix_stride_out = M * N;
 
   auto get_batch_dims = [](const auto& v) {
     return decltype(v){v.begin(), v.end() - 2};
diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp
index f2cb12fdd..18db2b3dd 100644
--- a/mlx/backend/cpu/primitives.cpp
+++ b/mlx/backend/cpu/primitives.cpp
@@ -48,7 +48,7 @@ static std::pair<array, bool> compute_dynamic_offset(
   auto compute_offset =
       [strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
         int64_t offset_ = 0;
-        for (int i = 0; i < axes.size(); ++i) {
+        for (int i = 0; i < std::ssize(axes); ++i) {
           offset_ += indices[i] * strides[axes[i]];
         }
         offset[0] = offset_;
@@ -193,9 +193,9 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;
-  for (int i = 0; i < inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(inputs); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
-    size_t data_offset = strides[axis_] * sizes[i];
+    int64_t data_offset = strides[axis_] * sizes[i];
     out_slice.copy_shared_buffer(
         out, strides, flags, out_slice.size(), data_offset);
     copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
@@ -205,7 +205,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  constexpr size_t extra_bytes = 16384;
+  constexpr int64_t extra_bytes = 16384;
   if (in.buffer_size() <= out.nbytes() + extra_bytes &&
       (in.flags().row_contiguous ||
        (allow_col_major_ && in.flags().col_contiguous))) {
@@ -254,8 +254,8 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
   copy_cpu(val, out, CopyType::Scalar, stream());
 
   // Find offset for start of input values
-  size_t data_offset = 0;
-  for (int i = 0; i < axes_.size(); i++) {
+  int64_t data_offset = 0;
+  for (int i = 0; i < std::ssize(axes_); i++) {
     auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
     data_offset += out.strides()[ax] * low_pad_size_[i];
   }
@@ -274,10 +274,10 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
   // keys has shape (N1, ..., NK, 2)
   // out has shape (N1, ..., NK, M1, M2, ...)
   auto& keys = inputs[0];
-  size_t num_keys = keys.size() / 2;
+  int64_t num_keys = keys.size() / 2;
 
-  size_t elems_per_key = out.size() / num_keys;
-  size_t bytes_per_key = out.itemsize() * elems_per_key;
+  int64_t elems_per_key = out.size() / num_keys;
+  int64_t bytes_per_key = out.itemsize() * elems_per_key;
   out.set_data(allocator::malloc(out.nbytes()));
 
   auto kptr = inputs[0].data<uint32_t>();
@@ -291,8 +291,8 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
        num_keys,
        kshape = keys.shape(),
        kstrides = keys.strides()]() mutable {
-    size_t out_skip = (bytes_per_key + 4 - 1) / 4;
-    auto half_size = out_skip / 2;
+    int64_t out_skip = (bytes_per_key + 4 - 1) / 4;
+    uintptr_t half_size = out_skip / 2;
     bool even = out_skip % 2 == 0;
     for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
       auto ptr = reinterpret_cast<uint32_t*>(cptr);
diff --git a/mlx/backend/cpu/qrf.cpp b/mlx/backend/cpu/qrf.cpp
index 13c7e1132..d3d6717e8 100644
--- a/mlx/backend/cpu/qrf.cpp
+++ b/mlx/backend/cpu/qrf.cpp
@@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   const int M = a.shape(-2);
   const int N = a.shape(-1);
   const int lda = M;
-  size_t num_matrices = a.size() / (M * N);
+  int64_t num_matrices = a.size() / (M * N);
 
   // Copy A to inplace input and make it col-contiguous
   array in(a.shape(), a.dtype(), nullptr, {});
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
     auto work = allocator::malloc(sizeof(T) * lwork);
 
     // Loop over matrices
-    for (int i = 0; i < num_matrices; ++i) {
+    for (int64_t i = 0; i < num_matrices; ++i) {
       // Solve
       geqrf<T>(
           &M,
@@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
     }
     allocator::free(work);
 
-    for (int i = 0; i < num_matrices; ++i) {
+    for (int64_t i = 0; i < num_matrices; ++i) {
       /// num_reflectors x N
       for (int j = 0; j < num_reflectors; ++j) {
         for (int k = 0; k < j; ++k) {
@@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
     work = allocator::malloc(sizeof(T) * lwork);
 
     // Loop over matrices
-    for (int i = 0; i < num_matrices; ++i) {
+    for (int64_t i = 0; i < num_matrices; ++i) {
      // Compute Q
       orgqr<T>(
           &M,
@@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
           &info);
     }
 
-    for (int i = 0; i < num_matrices; ++i) {
+    for (int64_t i = 0; i < num_matrices; ++i) {
       // M x num_reflectors
       for (int j = 0; j < M; ++j) {
         for (int k = 0; k < num_reflectors; ++k) {
diff --git a/mlx/backend/cpu/simd/math.h b/mlx/backend/cpu/simd/math.h
index f9fc8317a..9854f7e91 100644
--- a/mlx/backend/cpu/simd/math.h
+++ b/mlx/backend/cpu/simd/math.h
@@ -79,7 +79,8 @@ Simd<float, N> sincos(Simd<float, N> in) {
 
   // Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
   // and another one for Pi/4<x<=Pi/2. Both branches will be computed.
-  auto poly_mask = (emm2 & 2) != 0;
+  auto poly_mask =
+      (emm2 & static_cast<uint32_t>(2)) != static_cast<uint32_t>(0);
 
   // The magic pass: "Extended precision modular arithmetic"
   // x = ((x - y * DP1) - y * DP2) - y * DP3
@@ -87,8 +88,8 @@ Simd<float, N> sincos(Simd<float, N> in) {
   x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
   x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
 
-  sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
-  auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
+  sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast<uint32_t>(0));
+  auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast<uint32_t>(0);
 
   // Evaluate the first polynom (0 <= x <= Pi/4) in y1,
   // and the second polynom (Pi/4 <= x <= 0) in y2
diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp
index fcf12d7ad..8e05951aa 100644
--- a/mlx/backend/cpu/sort.cpp
+++ b/mlx/backend/cpu/sort.cpp
@@ -120,8 +120,8 @@ template <typename T>
 void sort(array& out, int axis) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + out.ndim() : axis;
-  size_t in_size = out.size();
-  size_t n_rows = in_size / out.shape(axis);
+  int64_t in_size = out.size();
+  int64_t n_rows = in_size / out.shape(axis);
 
   auto remaining_shape = out.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
@@ -136,7 +136,7 @@ void sort(array& out, int axis) {
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
   auto out_ptr = out.data<T>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     T* data_ptr = out_ptr + src_it.loc;
 
     StridedIterator st(data_ptr, axis_stride, 0);
@@ -151,7 +151,7 @@ template <typename T, typename IdxT>
 void argsort(const array& in, array& out, int axis) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t n_rows = in.size() / in.shape(axis);
+  int64_t n_rows = in.size() / in.shape(axis);
 
   auto in_remaining_shape = in.shape();
   in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -176,7 +176,7 @@ void argsort(const array& in, array& out, int axis) {
       out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
   auto in_ptr = in.data<T>();
   auto out_ptr = out.data<IdxT>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     const T* data_ptr = in_ptr + in_it.loc;
     IdxT* idx_ptr = out_ptr + out_it.loc;
 
@@ -214,8 +214,8 @@ template <typename T>
 void partition(array& out, int axis, int kth) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + out.ndim() : axis;
-  size_t in_size = out.size();
-  size_t n_rows = in_size / out.shape(axis);
+  int64_t in_size = out.size();
+  int64_t n_rows = in_size / out.shape(axis);
 
   auto remaining_shape = out.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
@@ -232,7 +232,7 @@ void partition(array& out, int axis, int kth) {
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
   auto out_ptr = out.data<T>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     T* data_ptr = out_ptr + src_it.loc;
     src_it.step();
 
@@ -248,7 +248,7 @@ template <typename T, typename IdxT>
 void argpartition(const array& in, array& out, int axis, int kth) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t n_rows = in.size() / in.shape(axis);
+  int64_t n_rows = in.size() / in.shape(axis);
 
   auto in_remaining_shape = in.shape();
   in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -277,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
 
   auto in_ptr = in.data<T>();
   auto out_ptr = out.data<IdxT>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     const T* data_ptr = in_ptr + in_it.loc;
     IdxT* idx_ptr = out_ptr + out_it.loc;
     in_it.step();
diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp
index 1fc94c382..54d15fabc 100644
--- a/mlx/backend/cpu/svd.cpp
+++ b/mlx/backend/cpu/svd.cpp
@@ -27,7 +27,7 @@ void svd_impl(
   const int N = a.shape(-1);
   const int K = std::min(M, N);
 
-  size_t num_matrices = a.size() / (M * N);
+  int64_t num_matrices = a.size() / (M * N);
 
   // lapack clobbers the input, so we have to make a copy.
   array in(a.shape(), a.dtype(), nullptr, {});
@@ -121,7 +121,7 @@ void svd_impl(
   auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
 
   // Loop over matrices.
-  for (int i = 0; i < num_matrices; i++) {
+  for (int64_t i = 0; i < num_matrices; i++) {
     gesdd<T>(
         /* jobz = */ jobz,
         // M and N are swapped since lapack expects column-major.
@@ -153,10 +153,10 @@ void svd_impl(
 
 template <typename T>
 void compute_svd(
-    const array& a,
-    bool compute_uv,
-    std::vector<array>& outputs,
-    Stream stream) {}
+    const array& /* a */,
+    bool /* compute_uv */,
+    std::vector<array>& /* outputs */,
+    Stream /* stream */) {}
 
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
diff --git a/mlx/backend/cpu/ternary.h b/mlx/backend/cpu/ternary.h
index a27a7f2a9..4674d9fef 100644
--- a/mlx/backend/cpu/ternary.h
+++ b/mlx/backend/cpu/ternary.h
@@ -136,7 +136,7 @@ void ternary_op(
   if (topt == TernaryOpType::ScalarScalarScalar) {
     *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
   } else if (topt == TernaryOpType::VectorVectorVector) {
-    for (size_t i = 0; i < out.size(); ++i) {
+    for (int64_t i = 0; i < out.size(); ++i) {
       *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
       a_ptr++;
       b_ptr++;
diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h
index 14c1dd479..8a4c64e69 100644
--- a/mlx/backend/cpu/unary.h
+++ b/mlx/backend/cpu/unary.h
@@ -10,8 +10,8 @@ namespace mlx::core {
 
 template <typename T, typename U = T, typename Op>
-void unary_op(const T* a, U* out, size_t shape, size_t stride) {
-  for (size_t i = 0; i < shape; i += 1) {
+void unary_op(const T* a, U* out, int64_t shape, int64_t stride) {
+  for (int64_t i = 0; i < shape; i += 1) {
     out[i] = Op{}(*a);
     a += stride;
   }
@@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) {
       src++;
     }
   } else {
-    size_t shape = ndim > 0 ? a.shape().back() : 1;
-    size_t stride = ndim > 0 ? a.strides().back() : 1;
+    int64_t shape = ndim > 0 ? a.shape().back() : 1;
+    int64_t stride = ndim > 0 ? a.strides().back() : 1;
     if (ndim <= 1) {
       unary_op<T, U, Op>(src, dst, shape, stride);
       return;
     }
     auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
-    for (size_t elem = 0; elem < a.size(); elem += shape) {
+    for (int64_t elem = 0; elem < a.size(); elem += shape) {
       unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
       it.step();
     }