This commit is contained in:
Awni Hannun 2025-08-22 09:42:46 -07:00
parent 51505c2d5a
commit 2afdf380b1
2 changed files with 7 additions and 5 deletions

View File

@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h" #include "mlx/distributed/primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@ -15,8 +15,8 @@ void AllReduce::eval_gpu(
assert(inputs.size() == 1); assert(inputs.size() == 1);
assert(outputs.size() == 1); assert(outputs.size() == 1);
auto set_input_output =
auto set_input_output = [s = stream()](const array& in, array& out) -> std::pair<array, array> { [s = stream()](const array& in, array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s); copy_gpu(in, out, CopyType::General, s);
return {out, out}; return {out, out};

View File

@ -56,7 +56,7 @@ def parse_hardware_ports(ports_string):
def get_num_nvidia_gpus(): def get_num_nvidia_gpus():
result = run(['nvidia-smi', "-L"], capture_output=True, text=True, check=True) result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
return len(result.stdout.strip().split("\n")) return len(result.stdout.strip().split("\n"))
@ -433,7 +433,9 @@ def launch_nccl(parser, hosts, args, command):
base_env = os.environ.copy() base_env = os.environ.copy()
base_env.update( base_env.update(
{ {
"NCCL_DEBUG": base_env.get("NCCL_DEBUG", "DEBUG"), "NCCL_DEBUG": base_env.get(
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
),
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication "NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host, "NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port), "NCCL_PORT": str(master_port),