mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
9392fc3f88
...
5722c147de
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5722c147de | ||
|
|
f6819a1f26 | ||
|
|
f93f87c802 |
@@ -25,6 +25,11 @@ MLX was developed with contributions from the following individuals:
|
|||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
|
# Organizations
|
||||||
|
|
||||||
|
MLX has received contributions from the following companies:
|
||||||
|
- NVIDIA Corporation & Affiliates
|
||||||
|
|
||||||
# Third-Party Software
|
# Third-Party Software
|
||||||
|
|
||||||
MLX leverages several third-party software, listed here together with
|
MLX leverages several third-party software, listed here together with
|
||||||
|
|||||||
@@ -30,8 +30,15 @@ SmallSizePool::SmallSizePool() {
|
|||||||
next_free_ = buffer_;
|
next_free_ = buffer_;
|
||||||
|
|
||||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||||
|
#if CUDART_VERSION >= 13000
|
||||||
|
cudaMemLocation loc;
|
||||||
|
loc.type = cudaMemLocationTypeDevice;
|
||||||
|
loc.id = 0;
|
||||||
|
#else
|
||||||
|
int loc = 0;
|
||||||
|
#endif // CUDART_VERSION >= 13000
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc));
|
||||||
|
|
||||||
auto curr = next_free_;
|
auto curr = next_free_;
|
||||||
for (size_t i = 1; i < num_blocks; ++i) {
|
for (size_t i = 1; i < num_blocks; ++i) {
|
||||||
|
|||||||
@@ -269,7 +269,13 @@ void CommandEncoder::commit() {
|
|||||||
if (node_count_ > 0) {
|
if (node_count_ > 0) {
|
||||||
if (!from_nodes_.empty()) {
|
if (!from_nodes_.empty()) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
||||||
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
|
graph_,
|
||||||
|
from_nodes_.data(),
|
||||||
|
to_nodes_.data(),
|
||||||
|
#if CUDART_VERSION >= 13000
|
||||||
|
nullptr, // edgeData
|
||||||
|
#endif // CUDART_VERSION >= 13000
|
||||||
|
from_nodes_.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
graph_key_ += ".";
|
graph_key_ += ".";
|
||||||
|
|||||||
@@ -205,8 +205,10 @@ struct Power {
|
|||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
T res = 1;
|
T res = 1;
|
||||||
// Raising an integer to a negative power is undefined
|
// Raising an integer to a negative power is undefined
|
||||||
if (exp < 0) {
|
if constexpr (cuda::std::is_signed_v<T>) {
|
||||||
return 0;
|
if (exp < 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
while (exp) {
|
while (exp) {
|
||||||
if (exp & 1) {
|
if (exp & 1) {
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ from select import select
|
|||||||
from subprocess import PIPE, Popen, run
|
from subprocess import PIPE, Popen, run
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Host:
|
class Host:
|
||||||
@@ -437,8 +439,8 @@ def launch_nccl(parser, hosts, args, command):
|
|||||||
try:
|
try:
|
||||||
for rank in range(world_size):
|
for rank in range(world_size):
|
||||||
env = base_env.copy()
|
env = base_env.copy()
|
||||||
env["MLX_RANK"] = str(rank)
|
env["MLX_RANK"] = str(rank % args.repeat_hosts)
|
||||||
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
|
||||||
p = Popen(command, env=env)
|
p = Popen(command, env=env)
|
||||||
procs.append(p)
|
procs.append(p)
|
||||||
|
|
||||||
@@ -708,7 +710,7 @@ def distributed_config():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi", "nccl"],
|
choices=["ring", "mpi", "nccl"],
|
||||||
default="ring",
|
default="nccl" if mx.cuda.is_available() else "ring",
|
||||||
help="Which distributed backend to configure",
|
help="Which distributed backend to configure",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -780,7 +782,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi", "nccl"],
|
choices=["ring", "mpi", "nccl"],
|
||||||
default="ring",
|
default="nccl" if mx.cuda.is_available() else "ring",
|
||||||
help="Which distributed backend to launch",
|
help="Which distributed backend to launch",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ auditwheel repair dist/* \
|
|||||||
--exclude libnvrtc* \
|
--exclude libnvrtc* \
|
||||||
--exclude libcuda* \
|
--exclude libcuda* \
|
||||||
--exclude libcudnn* \
|
--exclude libcudnn* \
|
||||||
|
--exclude libnccl* \
|
||||||
-w wheel_tmp
|
-w wheel_tmp
|
||||||
|
|
||||||
|
|
||||||
@@ -17,7 +18,7 @@ rm "${repaired_wheel}"
|
|||||||
mlx_so="mlx/lib/libmlx.so"
|
mlx_so="mlx/lib/libmlx.so"
|
||||||
rpath=$(patchelf --print-rpath "${mlx_so}")
|
rpath=$(patchelf --print-rpath "${mlx_so}")
|
||||||
base="\$ORIGIN/../../nvidia"
|
base="\$ORIGIN/../../nvidia"
|
||||||
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
|
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib
|
||||||
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
||||||
python ../python/scripts/repair_record.py ${mlx_so}
|
python ../python/scripts/repair_record.py ${mlx_so}
|
||||||
|
|
||||||
|
|||||||
1
setup.py
1
setup.py
@@ -297,6 +297,7 @@ if __name__ == "__main__":
|
|||||||
"nvidia-cublas-cu12==12.9.*",
|
"nvidia-cublas-cu12==12.9.*",
|
||||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||||
"nvidia-cudnn-cu12==9.*",
|
"nvidia-cudnn-cu12==9.*",
|
||||||
|
"nvidia-nccl-cu12",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
name = "mlx-cpu"
|
name = "mlx-cpu"
|
||||||
|
|||||||
Reference in New Issue
Block a user