From 88f993da3849ef45ea925dd102ca04e8b1b65c36 Mon Sep 17 00:00:00 2001 From: Valentin Roussellet Date: Tue, 24 Dec 2024 07:02:20 -0800 Subject: [PATCH] Explicit parentheses around some logical operators (#1732) * fix some warnings * format --- mlx/array.cpp | 2 +- mlx/backend/common/binary.h | 4 ++-- mlx/backend/common/primitives.cpp | 2 +- mlx/backend/metal/primitives.cpp | 2 +- mlx/ops.cpp | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index 9cf36416f..331914518 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -119,7 +119,7 @@ void array::eval() { } bool array::is_tracer() const { - return array_desc_->is_tracer && in_tracing() || retain_graph(); + return (array_desc_->is_tracer && in_tracing()) || retain_graph(); } void array::set_data(allocator::Buffer buffer, Deleter d) { diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index 7b9d6ec02..d6879cda4 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -28,8 +28,8 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) { } else if (b.data_size() == 1 && a.flags().contiguous) { bopt = BinaryOpType::VectorScalar; } else if ( - a.flags().row_contiguous && b.flags().row_contiguous || - a.flags().col_contiguous && b.flags().col_contiguous) { + (a.flags().row_contiguous && b.flags().row_contiguous) || + (a.flags().col_contiguous && b.flags().col_contiguous)) { bopt = BinaryOpType::VectorVector; } else { bopt = BinaryOpType::General; diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index c371eb7aa..8c831a9b5 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -605,7 +605,7 @@ void View::eval_cpu(const std::vector& inputs, array& out) { // - type size is the same // - type size is smaller and the last axis is contiguous // - the entire array is row contiguous - if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 || + if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || in.flags().row_contiguous) { auto strides = in.strides(); for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 012a5217f..4f1518d9a 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -440,7 +440,7 @@ void View::eval_gpu(const std::vector& inputs, array& out) { // - type size is the same // - type size is smaller and the last axis is contiguous // - the entire array is row contiguous - if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 || + if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || in.flags().row_contiguous) { auto strides = in.strides(); for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 3b54b43af..9386d59b9 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -157,7 +157,7 @@ array arange( // Check if start and stop specify a valid range because if not, we have to // return an empty array if (std::isinf(step) && - (step > 0 && start < stop || step < 0 && start > stop)) { + ((step > 0 && start < stop) || (step < 0 && start > stop))) { return array({start}, dtype); }