mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 00:35:27 +08:00
comment
This commit is contained in:
parent
51505c2d5a
commit
2afdf380b1
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/gpu/copy.h"
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
@ -15,8 +15,8 @@ void AllReduce::eval_gpu(
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(outputs.size() == 1);
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
auto set_input_output =
|
||||||
auto set_input_output = [s = stream()](const array& in, array& out) -> std::pair<array, array> {
|
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
|
||||||
if (!in.flags().row_contiguous) {
|
if (!in.flags().row_contiguous) {
|
||||||
copy_gpu(in, out, CopyType::General, s);
|
copy_gpu(in, out, CopyType::General, s);
|
||||||
return {out, out};
|
return {out, out};
|
||||||
|
@ -56,7 +56,7 @@ def parse_hardware_ports(ports_string):
|
|||||||
|
|
||||||
|
|
||||||
def get_num_nvidia_gpus():
|
def get_num_nvidia_gpus():
|
||||||
result = run(['nvidia-smi', "-L"], capture_output=True, text=True, check=True)
|
result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
|
||||||
return len(result.stdout.strip().split("\n"))
|
return len(result.stdout.strip().split("\n"))
|
||||||
|
|
||||||
|
|
||||||
@ -433,7 +433,9 @@ def launch_nccl(parser, hosts, args, command):
|
|||||||
base_env = os.environ.copy()
|
base_env = os.environ.copy()
|
||||||
base_env.update(
|
base_env.update(
|
||||||
{
|
{
|
||||||
"NCCL_DEBUG": base_env.get("NCCL_DEBUG", "DEBUG"),
|
"NCCL_DEBUG": base_env.get(
|
||||||
|
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
|
||||||
|
),
|
||||||
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
||||||
"NCCL_HOST_IP": master_host,
|
"NCCL_HOST_IP": master_host,
|
||||||
"NCCL_PORT": str(master_port),
|
"NCCL_PORT": str(master_port),
|
||||||
|
Loading…
Reference in New Issue
Block a user