diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index 0b396f7f6..e55d960e7 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -625,7 +625,7 @@ class RingGroup : public GroupImpl { std::min( sockets_right_.size() + sockets_left_.size(), nbytes / min_send_size), - 1UL); + size_t(1)); size_t bytes_per_gather = ceildiv(nbytes, n_gathers); std::vector> all_gathers; for (int i = 0; i < n_gathers; i++) { @@ -740,7 +740,7 @@ class RingGroup : public GroupImpl { std::min( sockets_right_.size() + sockets_left_.size(), nbytes / (size_ * min_send_size)), - 1UL); + size_t(1)); size_t step = ceildiv(size, n_reduces); std::vector> all_sums; @@ -777,8 +777,8 @@ class RingGroup : public GroupImpl { // We split the data into `size_` segments of size `segment_size` and each // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets. size_t segment_size = ceildiv(data_size, size_); - size_t BUFFER_SIZE = - std::max(32768UL, std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); + size_t BUFFER_SIZE = std::max( + size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); size_t n_packets = ceildiv(segment_size, BUFFER_SIZE); // Initial segments @@ -897,7 +897,8 @@ class RingGroup : public GroupImpl { void send(const std::vector& sockets, const char* data, size_t data_size) { - size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size())); + size_t segment_size = + std::max(size_t(1024), ceildiv(data_size, sockets.size())); std::vector> sends; for (int i = 0; i < sockets.size(); i++) { if (i * segment_size >= data_size) { @@ -914,7 +915,8 @@ class RingGroup : public GroupImpl { } void recv(const std::vector& sockets, char* data, size_t data_size) { - size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size())); + size_t segment_size = + std::max(size_t(1024), ceildiv(data_size, sockets.size())); std::vector> recvs; for (int i = 0; i < sockets.size(); i++) { if (i * segment_size >= data_size) {