diff --git a/mlx/array.cpp b/mlx/array.cpp
index d8d12e0db..a05e8dfa7 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -241,8 +241,8 @@ array::ArrayDesc::ArrayDesc(
     std::vector<array> inputs)
     : shape(std::move(shape)),
       dtype(dtype),
-      status(Status::unscheduled),
       primitive(std::move(primitive)),
+      status(Status::unscheduled),
       inputs(std::move(inputs)) {
   init();
 }
diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp
index 5684f9709..5d4638ade 100644
--- a/mlx/backend/cpu/conv.cpp
+++ b/mlx/backend/cpu/conv.cpp
@@ -996,131 +996,6 @@ void explicit_gemm_conv_1D_cpu(
   encoder.add_temporaries(std::move(temps));
 }
 
-void explicit_gemm_conv_2D_cpu(
-    const array& in,
-    const array& wt,
-    array out,
-    const std::vector<int>& padding_lo,
-    const std::vector<int>& padding_hi,
-    const std::vector<int>& wt_strides,
-    const std::vector<int>& wt_dilation,
-    Stream stream) {
-  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
-  const int iH = in.shape(1); // Input spatial dim
-  const int iW = in.shape(2); // Input spatial dim
-  const int oH = out.shape(1); // Output spatial dim
-  const int oW = out.shape(2); // Output spatial dim
-  const int O = wt.shape(0); // Out channels
-  const int C = wt.shape(3); // In channels
-  const int wH = wt.shape(1); // Weight spatial dim
-  const int wW = wt.shape(2); // Weight spatial dim
-
-  auto conv_dtype = out.dtype();
-  auto& encoder = cpu::get_command_encoder(stream);
-
-  // Pad input
-  Shape padded_shape = {
-      N,
-      iH + padding_lo[0] + padding_hi[0],
-      iW + padding_lo[1] + padding_hi[1],
-      C};
-  array in_padded(padded_shape, conv_dtype, nullptr, {});
-
-  // Fill with zeros
-  std::vector<array> temps;
-  temps.push_back(array(0, conv_dtype));
-  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
-
-  // Pick input slice from padded
-  size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
-      padding_lo[1] * in_padded.strides()[2];
-  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
-  in_padded_slice.copy_shared_buffer(
-      in_padded,
-      in_padded.strides(),
-      in_padded.flags(),
-      in_padded_slice.size(),
-      data_offset);
-  temps.push_back(in_padded_slice);
-
-  // Copy input values into the slice
-  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
-
-  // Make strided view
-  Shape strided_shape = {N, oH, oW, wH, wW, C};
-
-  Strides strided_strides = {
-      in_padded.strides()[0],
-      in_padded.strides()[1] * wt_strides[0],
-      in_padded.strides()[2] * wt_strides[1],
-      in_padded.strides()[1],
-      in_padded.strides()[2],
-      in_padded.strides()[3]};
-  auto flags = in_padded.flags();
-
-  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
-  in_strided_view.copy_shared_buffer(
-      in_padded, strided_strides, flags, in_strided_view.size(), 0);
-
-  // Materialize strided view
-  Shape strided_reshape = {N * oH * oW, wH * wW * C};
-  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
-  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
-  temps.push_back(in_strided);
-
-  // Check wt dtype and prepare
-  auto gemm_wt = wt;
-  auto gemm_out = out;
-
-  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
-    auto ctype =
-        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
-    gemm_wt = array(wt.shape(), float32, nullptr, {});
-    copy_cpu(wt, gemm_wt, ctype, stream);
-    temps.push_back(gemm_wt);
-  }
-
-  if (out.dtype() != float32) {
-    gemm_out = array(out.shape(), float32, nullptr, {});
-    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
-    temps.push_back(gemm_out);
-  }
-
-  encoder.set_input_array(in_strided);
-  encoder.set_input_array(gemm_wt);
-  encoder.set_output_array(gemm_out);
-
-  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
-                    gemm_wt_ptr = gemm_wt.data<float>(),
-                    gemm_out_ptr = gemm_out.data<float>(),
-                    strided_reshape = std::move(strided_reshape),
-                    O]() {
-    // Perform gemm
-    cblas_sgemm(
-        CblasRowMajor,
-        CblasNoTrans, // no trans A
-        CblasTrans, // transB
-        strided_reshape[0], // M
-        O, // N
-        strided_reshape[1], // K
-        1.0f, // alpha
-        in_strided_ptr,
-        strided_reshape[1], // lda
-        gemm_wt_ptr,
-        strided_reshape[1], // ldb
-        0.0f, // beta
-        gemm_out_ptr,
-        O // ldc
-    );
-  });
-
-  // Copy results if needed
-  if (out.dtype() != float32) {
-    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
-  }
-  encoder.add_temporaries(std::move(temps));
-}
-
 void explicit_gemm_conv_ND_cpu(
     const array& in,
     const array& wt,
diff --git a/mlx/backend/cpu/eig.cpp b/mlx/backend/cpu/eig.cpp
index a01295145..0d1f95a57 100644
--- a/mlx/backend/cpu/eig.cpp
+++ b/mlx/backend/cpu/eig.cpp
@@ -46,7 +46,6 @@ void eig_impl(
   int info;
   {
     T work;
-    int iwork;
     geev(
         &jobl,
         &jobr,
diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp
index c3efb79cd..688479c60 100644
--- a/mlx/backend/cpu/masked_mm.cpp
+++ b/mlx/backend/cpu/masked_mm.cpp
@@ -215,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(a);
   encoder.set_input_array(b);
 
-  const void* a_mask_ptr;
-  const void* b_mask_ptr;
-  const void* out_mask_ptr;
+  const void* a_mask_ptr = nullptr;
+  const void* b_mask_ptr = nullptr;
+  const void* out_mask_ptr = nullptr;
   Shape a_mask_shape;
   Shape b_mask_shape;
   Shape out_mask_shape;
   Strides a_mask_strides;
   Strides b_mask_strides;
   Strides out_mask_strides;
-  bool a_mask_bool;
-  bool b_mask_bool;
-  bool out_mask_bool;
+  bool a_mask_bool = false;
+  bool b_mask_bool = false;
+  bool out_mask_bool = false;
   if (has_op_mask) {
     auto& a_mask = inputs[inputs.size() - 2];
     auto& b_mask = inputs[inputs.size() - 1];
@@ -423,7 +423,6 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& rhs_indices = inputs[3];
 
   auto batch_shape = get_batch_dims(out.shape());
-  int batch_ndim = batch_shape.size();
 
   auto batch_shape_A = get_batch_dims(a.shape());
   auto batch_strides_A = get_batch_dims(a.strides());
diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp
index 777b31c02..b346e84db 100644
--- a/mlx/backend/cpu/matmul.cpp
+++ b/mlx/backend/cpu/matmul.cpp
@@ -91,7 +91,6 @@ void matmul_general(
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
   size_t M = a.shape(-2);
   size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
   if (M == 0 || N == 0) {
     return;
   }
diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp
index a475131f7..de50cdb81 100644
--- a/mlx/backend/cpu/quantized.cpp
+++ b/mlx/backend/cpu/quantized.cpp
@@ -445,7 +445,6 @@ void mxfp4_qmm(
     int K) {
   constexpr int group_size = 32;
   constexpr int pack_factor = get_pack_factor(4, 8);
-  constexpr int bytes_per_pack = get_bytes_per_pack(4);
   constexpr int packs_in_group = group_size / pack_factor;
 
   for (int m = 0; m < M; m++) {
@@ -487,7 +486,6 @@ void mxfp4_qmm_t(
     int K) {
   constexpr int group_size = 32;
   constexpr int pack_factor = get_pack_factor(4, 8);
-  constexpr int bytes_per_pack = get_bytes_per_pack(4);
   constexpr int packs_in_group = group_size / pack_factor;
 
   for (int m = 0; m < M; m++) {
diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp
index 56e7b939c..fcf12d7ad 100644
--- a/mlx/backend/cpu/sort.cpp
+++ b/mlx/backend/cpu/sort.cpp
@@ -39,7 +39,7 @@ struct StridedIterator {
   StridedIterator() = default;
 
   explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
-      : ptr_(ptr + offset * stride), stride_(stride) {}
+      : stride_(stride), ptr_(ptr + offset * stride) {}
 
   explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
       : StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp
index 6e57eb401..1fc94c382 100644
--- a/mlx/backend/cpu/svd.cpp
+++ b/mlx/backend/cpu/svd.cpp
@@ -83,8 +83,6 @@ void svd_impl(
 
   auto jobz = (u_ptr) ? "A" : "N";
 
-  // Will contain the number of singular values after the call has returned.
-  int ns = 0;
   T workspace_dimension = 0;
 
   // Will contain the indices of eigenvectors that failed to converge (not
diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index 8eb70bcbc..fec940b24 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -32,7 +32,6 @@ namespace metal {
 
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
-      residency_set_(device_),
       buffer_cache_(
           vm_page_size,
           [](MTL::Buffer* buf) { return buf->length(); },
@@ -41,7 +40,8 @@ MetalAllocator::MetalAllocator()
             residency_set_.erase(buf);
           }
           buf->release();
-          }) {
+          }),
+      residency_set_(device_) {
   auto pool = metal::new_scoped_memory_pool();
   auto memsize = std::get<size_t>(device_info().at("memory_size"));
   auto max_rec_size =
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index 691317916..216735ad3 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -65,7 +65,6 @@ class MetalAllocator : public allocator::Allocator {
   size_t peak_memory_{0};
   size_t max_pool_size_;
   size_t wired_limit_{0};
-  bool relaxed_{true};
   size_t num_resources_{0};
   size_t resource_limit_{0};
 
diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp
index 41e399ce3..c48b93c91 100644
--- a/mlx/backend/metal/custom_kernel.cpp
+++ b/mlx/backend/metal/custom_kernel.cpp
@@ -327,6 +327,10 @@ CustomKernelFunction metal_kernel(
 void CustomKernel::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
+  // silence some warnings
+  (void)is_precompiled_;
+  (void)shared_memory_;
+
   auto& s = stream();
 
   std::vector<array> copies;
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 5f2bb73d3..bd9d16b15 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -728,7 +728,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
   mtl_linked_funcs->release();
 
   // Add kernel to cache
-  auto inserted = kernel_map_.insert({hash_name, kernel});
+  kernel_map_.insert({hash_name, kernel});
 
   return kernel;
 }
diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp
index 49783200a..100058925 100644
--- a/mlx/backend/metal/eval.cpp
+++ b/mlx/backend/metal/eval.cpp
@@ -71,7 +71,7 @@ void eval(array& arr) {
     d.get_command_buffer(s.index);
   } else {
     command_buffer->addCompletedHandler(
-        [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
+        [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
          check_error(cbuf);
        });
   }
 }
@@ -82,7 +82,7 @@ void finalize(Stream s) {
   auto& d = metal::device(s.device);
   auto cb = d.get_command_buffer(s.index);
   d.end_encoding(s.index);
-  cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
+  cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
   d.commit_command_buffer(s.index);
   d.get_command_buffer(s.index);
 }
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index 1e23160a6..d329a4685 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -150,7 +150,6 @@ FFTPlan plan_fft(int n) {
       }
       // See if we can use Rader's algorithm to Stockham decompose n - 1
       auto rader_factors = prime_factors(factor - 1);
-      int last_factor = -1;
       for (int rf : rader_factors) {
         // We don't nest Rader's algorithm so if `factor - 1`
         // isn't Stockham decomposable we give up and do Bluestein's.
@@ -313,8 +312,6 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
   //  w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
   //  w_q = np.fft.fft(1/w_k)
   //  return w_k, w_q
-  int length = 2 * n - 1;
-
   std::vector<std::complex<float>> w_k_vec(n);
   std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
 
@@ -484,8 +481,6 @@ void four_step_fft(
     std::vector<array>& copies,
     const Stream& s,
     bool in_place) {
-  auto& d = metal::device(s.device);
-
   if (plan.bluestein_n == -1) {
     // Fast no transpose implementation for powers of 2.
     FourStepParams four_step_params = {
@@ -786,7 +781,6 @@ void nd_fft_op(
     // Mirror np.fft.(i)rfftn and perform a real transform
     // only on the final axis.
     bool step_real = (real && index == axes.size() - 1);
-    auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
     const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
     array& out_arr = i == 0 ? out : temp_arrs[i % 2];
     fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index d9497a665..d6bee651d 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -354,8 +354,6 @@ void steel_gemm_splitk_axpby(
     float beta = 0.0f) {
   using namespace mlx::steel;
 
-  int _tm = M / 16;
-  int _tn = N / 16;
   int _tk = K / 16;
 
   int bm = M < 40 ? 16 : 32;
@@ -659,16 +657,11 @@ void gemv_axbpy(
   int in_vector_len = K;
   int out_vector_len = is_b_matrix ? N : M;
 
-  int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
-  int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
   int mat_ld = is_b_matrix ? ldb : lda;
 
   auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
   auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
 
-  int stride_mat = batch_strides_mat.back();
-  int stride_vec = batch_strides_vec.back();
-
   // Determine if inputs have simple batching / broadcasting
   bool contiguous_kernel = (batch_shape.size() == 1);
 
@@ -964,12 +957,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
 
   array c = c_pre;
-  int ldc = c.strides()[c.ndim() - 2];
-  int fdc = c.strides()[c.ndim() - 1];
 
   int lda = a_cols;
   int ldb = b_cols;
-  int ldd = N;
 
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
@@ -1101,10 +1091,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
   std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
 
-  auto get_batch_dims = [](const auto& v) {
-    return decltype(v){v.begin(), v.end() - 2};
-  };
-
   Shape batch_shape{1};
   Strides A_batch_stride{0};
   Strides B_batch_stride{0};
@@ -1165,8 +1151,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     int in_vector_len = K;
     int out_vector_len = is_b_matrix ? N : M;
 
-    int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
-    int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
     int mat_ld = is_b_matrix ? b_cols : a_cols;
 
     auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index 40e7b5bc8..da0160b24 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -102,17 +102,15 @@ void RMSNormVJP::eval_gpu(
   // Ensure row contiguity. We could relax this step by checking that the array
   // is contiguous (no broadcasts or holes) and that the input strides are the
   // same as the cotangent strides but for now this is simpler.
-  auto check_input = [&d, &s](const array& x) -> std::pair<array, bool> {
+  auto check_input = [&s](const array& x) -> std::pair<array, bool> {
     if (x.flags().row_contiguous) {
       return {x, false};
     }
     array x_copy = contiguous_copy_gpu(x, s);
     return {x_copy, true};
   };
-  bool donate_x = inputs[0].is_donatable();
   bool donate_g = inputs[2].is_donatable();
   auto [x, copied] = check_input(inputs[0]);
-  donate_x |= copied;
   const array& w = inputs[1];
   auto [g, g_copied] = check_input(inputs[2]);
   donate_g |= g_copied;
@@ -323,7 +321,6 @@ void LayerNormVJP::eval_gpu(
   auto [x, copied] = check_input(inputs[0]);
   donate_x |= copied;
   const array& w = inputs[1];
-  const array& b = inputs[2];
   auto [g, g_copied] = check_input(inputs[3]);
   donate_g |= g_copied;
   array& gx = outputs[0];
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 92a9a4158..5f6376c5e 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -182,7 +182,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // organize into grid nkeys x elem_per_key
   MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   auto group_dims = get_block_dims(num_keys, half_size + odd, 1);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp
index ca208a36c..0b636c162 100644
--- a/mlx/backend/metal/scan.cpp
+++ b/mlx/backend/metal/scan.cpp
@@ -108,7 +108,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     size_t stride = in.strides()[axis_];
-    int bm = 32;
     int bn = 32;
     size_t stride_blocks = (stride + bn - 1) / bn;
     compute_encoder.set_bytes(size, 2);
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 48f85635b..8b983a45a 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -26,7 +26,7 @@ void unary_op_gpu_inplace(
   auto& d = metal::device(s.device);
 
-  auto maybe_collapse = [contig, &in, &out]() {
+  auto maybe_collapse = [contig, &in]() {
     if (!contig) {
       return collapse_contiguous_dims(in);
     } else {
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 90791b02e..d762c8d15 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -213,7 +213,7 @@ namespace detail {
 
 CompileMode& compile_mode() {
   auto get_val = []() {
-    if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
+    if (std::getenv("MLX_DISABLE_COMPILE")) {
       return CompileMode::disabled;
     } else {
       return CompileMode::enabled;
@@ -282,7 +282,7 @@ array split_one(
     }
   }
 
-  return std::move(y);
+  return y;
 }
 
 template
@@ -493,7 +493,6 @@ void compile_simplify(
   };
   auto get_scalar_rep = [](const array& a) {
     uint64_t v = 0;
-    int dtype;
     switch (a.dtype().size()) {
       case 1:
         v = *a.data<uint8_t>();
diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp
index 7c3dcf095..ac55ea30b 100644
--- a/mlx/distributed/ring/ring.cpp
+++ b/mlx/distributed/ring/ring.cpp
@@ -365,7 +365,7 @@ std::vector> load_nodes(const char* hostfile) {
   for (auto& h : hosts) {
     std::vector host;
     for (auto& ips : h) {
-      host.push_back(std::move(parse_address(ips.get())));
+      host.push_back(parse_address(ips.get()));
     }
     nodes.push_back(std::move(host));
   }
@@ -554,14 +554,14 @@ class RingGroup : public GroupImpl {
     // first and accept after.
     if (rank_ < connect_to) {
       log_info(verbose_, "Rank", rank_, "accepting");
-      sockets_left_ = std::move(accept_connections(nodes[rank_]));
+      sockets_left_ = accept_connections(nodes[rank_]);
       log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
-      sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
+      sockets_right_ = make_connections(nodes[connect_to], verbose);
     } else {
       log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
-      sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
+      sockets_right_ = make_connections(nodes[connect_to], verbose);
       log_info(verbose_, "Rank", rank_, "accepting");
-      sockets_left_ = std::move(accept_connections(nodes[rank_]));
+      sockets_left_ = accept_connections(nodes[rank_]);
     }
 
     // Failure if we couldn't make right or left sockets
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index e5bd3ad72..0f34aec93 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -418,7 +418,7 @@ array rope(
   auto positions =
       multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
 
-  auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
+  auto default_inv_freqs = [&s, &t, base, half_dims]() {
     return exp(
         multiply(
             arange(0, -half_dims, -1, t, s),
@@ -687,7 +687,6 @@ array scaled_dot_product_attention(
   auto v = astype(values, final_type, s);
 
   auto fallback = [scale,
-                   final_type,
                    n_q_heads,
                    n_kv_heads,
                    do_causal,
@@ -696,8 +695,6 @@ array scaled_dot_product_attention(
                    s](const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
-    int B = q.shape(0);
-    int L = q.shape(2);
     auto k = inputs[1];
     auto v = inputs[2];
     if (n_repeats > 1) {
diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp
index 649a7f78d..206f6fb31 100644
--- a/mlx/io/gguf.cpp
+++ b/mlx/io/gguf.cpp
@@ -441,7 +441,7 @@ void save_gguf(
     const char* tensorname = key.c_str();
     const uint64_t namelen = key.length();
     const uint32_t num_dim = arr.ndim();
-    uint64_t dim[num_dim];
+    std::vector<uint64_t> dim(num_dim);
     for (int i = 0; i < num_dim; i++) {
       dim[i] = arr.shape()[num_dim - 1 - i];
     }
@@ -450,7 +450,7 @@ void save_gguf(
         tensorname,
         namelen,
         num_dim,
-        dim,
+        dim.data(),
         gguf_type.value(),
         tensor_offset)) {
       throw std::runtime_error("[save_gguf] gguf_append_tensor_info failed");
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index e8a9e430e..afa8e447a 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -271,7 +271,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
   if (!compute_uv) {
     return {array(
         std::move(s_shape),
-        std::move(a.dtype()),
+        a.dtype(),
         std::make_shared<SVD>(to_stream(s), compute_uv),
         {a})};
   }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index a65709752..dc985ce95 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -244,7 +244,7 @@ array linspace(
 
 array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
   if (dtype == a.dtype()) {
-    return std::move(a);
+    return a;
   }
   auto copied_shape = a.shape(); // |a| will be moved
   return array(
@@ -2129,7 +2129,6 @@ array min(
 }
 
 array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
-  int size = a.size();
   auto result = argmin(flatten(a, s), 0, true, s);
   if (keepdims) {
     std::vector<int> axes(a.ndim() - 1);
@@ -2167,7 +2166,6 @@ array argmin(
 }
 
 array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
-  int size = a.size();
   auto result = argmax(flatten(a, s), 0, true, s);
   if (keepdims) {
     std::vector<int> axes(a.ndim() - 1);
@@ -4273,9 +4271,6 @@ array affine_dequantize(
   if (is_power_of_2(bits)) {
     std::vector<array> parts;
     for (int start = 0; start < 32; start += bits) {
-      int shift_left = 32 - (start + bits);
-      int shift_right = shift_left + start;
-
       parts.push_back(expand_dims(
           right_shift(
               left_shift(w, array(32 - (start + bits), uint32), s),
@@ -4883,7 +4878,7 @@ array block_masked_mm(
   };
 
   // Out mask
-  if (mask_out.has_value()) {
+  if (has_out_mask) {
     array mask_out_p = mask_out.value_or(array({true}));
     if (in_a_ndim == 1 || in_b_ndim == 1) {
       std::vector<int> ex_dims;
@@ -5015,7 +5010,6 @@ array gather_mm(
 
   int M = a.shape(-2);
   int N = b.shape(-1);
-  int K = a.shape(-1);
 
   std::tie(lhs_indices, rhs_indices) =
       broadcast_arrays(lhs_indices, rhs_indices, s);
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 655a55910..0b335e765 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1105,7 +1105,6 @@ std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
   // Make sure vmapped arrays have all vmapped axes in the same location and
   // expand non-vmapped arrays to be compatible with the vmapped ones.
   std::vector<array> t_inputs;
-  int N = inputs[first_vmap].shape(out_ax);
   int axis = axis_ + (axis_ >= out_ax);
   auto cat_shape = inputs[first_vmap].shape();
   for (int i = 0; i < axes.size(); i++) {
@@ -3475,7 +3474,6 @@ std::vector<array> GatherQMM::vjp(
       : std::nullopt;
 
   int M = cotan.shape(-2);
-  int N = cotan.shape(-1);
   int K = x.shape(-1);
 
   bool sorted = left_sorted_ || right_sorted_;
@@ -4536,7 +4534,6 @@ std::vector<array> SliceUpdate::vjp(
   assert(primals.size() == 2);
 
   auto& cotan = cotangents[0];
-  auto& src = primals[0];
   auto& upd = primals[1];
 
   std::vector<array> vjps;
@@ -5116,12 +5113,8 @@ std::vector<array> BlockMaskedMM::vjp(
   const int op_mask_idx = has_out_mask ? 3 : 2;
   bool needs_lhs_mask_vjp = has_op_mask;
   bool needs_rhs_mask_vjp = has_op_mask;
-  bool needs_lhs_vjp = false;
-  bool needs_rhs_vjp = false;
 
   for (auto arg : argnums) {
-    needs_lhs_vjp = arg == 0;
-    needs_rhs_vjp = arg == 1;
     needs_lhs_mask_vjp = arg == op_mask_idx;
     needs_rhs_mask_vjp = arg == op_mask_idx + 1;
   }
@@ -5346,7 +5339,6 @@ std::vector<array> GatherMM::vjp(
   auto& rhs_indices = primals[3];
 
   int M = cotan.shape(-2);
-  int N = cotan.shape(-1);
   int K = primals[0].shape(-1);
 
   bool sorted = left_sorted_ || right_sorted_;
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index bcf0cc09f..3ec64feea 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -916,7 +916,7 @@ std::function<array(const array&, const array&)> vmap(
     int in_axis_a /* = 0 */,
     int in_axis_b /* = 0 */,
     int out_axis /* = 0 */) {
   auto vfun = vmap(
-      [in_axis_a, in_axis_b, out_axis, fun](const std::vector<array>& inputs) {
+      [fun](const std::vector<array>& inputs) {
        return std::vector<array>{fun(inputs[0], inputs[1])};
      },
      {in_axis_a, in_axis_b},
@@ -929,7 +929,7 @@ std::function<array(const array&)> vmap(
     int in_axis /* = 0 */,
     int out_axis /* = 0 */) {
   auto vfun = vmap(
-      [in_axis, out_axis, fun](const std::vector<array>& inputs) {
+      [fun](const std::vector<array>& inputs) {
        return std::vector<array>{fun(inputs[0])};
      },
      {in_axis},
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 2a850d9f9..1ae259b1e 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -176,17 +176,6 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
 
 namespace {
 
-inline size_t
-elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
-  size_t loc = 0;
-  for (int i = shape.size() - 1; i >= 0; --i) {
-    auto q_and_r = ldiv(elem, shape[i]);
-    loc += q_and_r.rem * strides[i];
-    elem = q_and_r.quot;
-  }
-  return loc;
-}
-
 template <typename T>
 void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
   int num_print = 3;
diff --git a/tests/gpu_tests.cpp b/tests/gpu_tests.cpp
index 625cbf552..58cca348e 100644
--- a/tests/gpu_tests.cpp
+++ b/tests/gpu_tests.cpp
@@ -447,7 +447,7 @@ TEST_CASE("test gpu matmul") {
 TEST_CASE("test gpu validation") {
   // Run this test with Metal validation enabled
   // METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \
-  //   -tc="test metal validation" \
+  //   -tc="test metal validation"
   auto x = array({});
   eval(exp(x));
 
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 878c7101b..2e8bbd692 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3635,7 +3635,6 @@ TEST_CASE("test conv1d") {
       {1, 3, 2}),
       float16);
 
-  int kernel = 3;
   int stride = 1;
   int padding = 1;
 
@@ -3735,7 +3734,6 @@ TEST_CASE("test conv2d") {
       -0.26912728},
       {1, 2, 2, 2});
 
-  std::pair<int, int> kernel{2, 2};
   std::pair<int, int> stride{1, 1};
   std::pair<int, int> padding{0, 0};