Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Compare commits: 333ffea273 ... 3dcb286baf (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 3dcb286baf | |
| | 4822c3dbe9 | |
| | 2ca75bb529 | |
| | db14e29a0b | |
| | d2f540f4e0 | |
```diff
@@ -394,7 +394,7 @@ jobs:
 default: ""
 machine:
 image: ubuntu-2204:current
-resource_class: large
+resource_class: xlarge
 steps:
 - checkout
 - run:
@@ -406,7 +406,6 @@ jobs:
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 sudo apt-get update
 sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
-sudo apt-get install libnccl2 libnccl-dev
 sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
 sudo apt-get install zip
 pip install auditwheel
```
```diff
@@ -15,8 +15,8 @@ void copy_gpu_inplace(
     int64_t offset_out,
     CopyType ctype,
     const Stream& s,
-    const std::optional<array>& dynamic_offset_in,
-    const std::optional<array>& dynamic_offset_out) {
+    std::optional<array> dynamic_offset_in,
+    std::optional<array> dynamic_offset_out) {
   if (out.size() == 0) {
     return;
   }
@@ -44,6 +44,16 @@ void copy_gpu_inplace(
         strides_vec[0]);
   } else {
     if (dynamic_offset_in || dynamic_offset_out) {
+      if (!dynamic_offset_in) {
+        dynamic_offset_in = array(0, int64);
+        encoder.add_temporary(*dynamic_offset_in);
+      }
+      if (!dynamic_offset_out) {
+        dynamic_offset_out = array(0, int64);
+        encoder.add_temporary(*dynamic_offset_out);
+      }
+      encoder.set_input_array(*dynamic_offset_in);
+      encoder.set_input_array(*dynamic_offset_out);
       copy_general_dynamic(
           encoder,
           ctype,
@@ -54,8 +64,8 @@ void copy_gpu_inplace(
           shape_collapsed,
           strides_vec[0],
           strides_vec[1],
-          dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
-          dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
+          *dynamic_offset_in,
+          *dynamic_offset_out);
     } else {
       copy_general(
           encoder,
```
```diff
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   args.append<int32_t>(src.ndim());
   args.append_ndim(slice_sizes_);
   args.append(slice_size);
-  args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
+  args.append(axes_);
   append_indices_arg(args, inputs, nidx, idx_ndim);

   std::string kernel_name = fmt::format(
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   args.append_ndim(out.shape());
   args.append_ndim(out.strides());
   args.append<int32_t>(out.ndim());
-  args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
+  args.append(axes_);
   append_indices_arg(args, inputs, nidx, idx_ndim);

   std::string kernel_name = fmt::format(
```
```diff
@@ -46,6 +46,11 @@ struct KernelArgs {
     append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
   }

+  template <typename T>
+  void append(const std::vector<T>& vec) {
+    append(SmallVector<T>(vec.begin(), vec.end()));
+  }
+
   // Make sure the arg is copied to an array with size of NDIM.
   template <size_t NDIM = MAX_NDIM, typename T>
   void append_ndim(SmallVector<T> vec) {
```
```diff
@@ -24,8 +24,6 @@ namespace mlx::core {
 }

 NO_GPU(BlockMaskedMM)
-NO_GPU(DynamicSlice)
-NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)
 NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
```
```diff
@@ -1,8 +1,11 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/common/slicing.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/jit_module.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
+#include "mlx/dtype_utils.h"

 #include <numeric>
@@ -38,4 +41,71 @@ void concatenate_gpu(
   }
 }

+array compute_dynamic_offset(
+    const array& indices,
+    const Strides& strides,
+    const std::vector<int>& axes,
+    const Stream& s) {
+  Dtype dtype = indices.dtype();
+  int nidx = axes.size();
+
+  std::string module_name =
+      fmt::format("compute_dynamic_offset_{}_{}", dtype_to_string(dtype), nidx);
+  std::string kernel_name = fmt::format(
+      "mlx::core::cu::compute_dynamic_offset<{}, {}>",
+      dtype_to_cuda_type(dtype),
+      nidx);
+
+  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
+    std::string source = R"(
+#include "mlx/backend/cuda/device/utils.cuh"
+
+namespace mlx::core::cu {
+
+template <typename T, int NIDX>
+__global__ void compute_dynamic_offset(
+    const T* indices,
+    int64_t* offset,
+    const __grid_constant__ Strides strides,
+    const __grid_constant__ cuda::std::array<int, NIDX> axes) {
+  int64_t acc = 0;
+#pragma unroll
+  for (int i = 0; i < NIDX; ++i) {
+    acc += indices[i] * strides[axes[i]];
+  }
+  *offset = acc;
+}
+
+} // namespace mlx::core::cu
+)";
+    return std::make_tuple(false, std::move(source), std::vector{kernel_name});
+  });
+
+  // Prepare output.
+  array offset({1}, int64, nullptr, {});
+  bool donate = indices.is_donatable() &&
+      (indices.data_size() * indices.itemsize()) >= offset.itemsize();
+  if (donate) {
+    offset.copy_shared_buffer(indices);
+  } else {
+    offset.set_data(allocator::malloc(offset.itemsize()));
+  }
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.add_temporary(offset);
+  encoder.set_input_array(indices);
+  encoder.set_output_array(offset);
+
+  cu::KernelArgs args;
+  args.append(indices);
+  args.append(offset);
+  args.append_ndim(strides);
+  args.append(axes);
+
+  auto kernel = mod.get_kernel(kernel_name);
+  encoder.add_kernel_node(kernel, 1, 1, 0, args.args());
+
+  return offset;
+}
+
 } // namespace mlx::core
```
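The new kernel runs a single thread and reduces the selected start indices against the row strides. A rough Python sketch of that arithmetic (illustrative names only, not MLX internals):

```python
# Sketch of the arithmetic performed by the single-thread
# compute_dynamic_offset kernel above: dot the start indices with the
# strides of the sliced axes to get one int64 element offset.
def dynamic_offset(indices, strides, axes):
    acc = 0
    for i, ax in enumerate(axes):
        acc += indices[i] * strides[ax]
    return acc

# A (4, 5) row-major array has strides (5, 1); starting a slice at
# row 2, column 3 gives element offset 2 * 5 + 3 * 1 = 13.
assert dynamic_offset([2, 3], [5, 1], [0, 1]) == 13
```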
```diff
@@ -20,8 +20,8 @@ void copy_gpu_inplace(
     int64_t o_offset,
     CopyType ctype,
     const Stream& s,
-    const std::optional<array>& dynamic_i_offset = std::nullopt,
-    const std::optional<array>& dynamic_o_offset = std::nullopt);
+    std::optional<array> dynamic_i_offset = std::nullopt,
+    std::optional<array> dynamic_o_offset = std::nullopt);

 void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
 void copy_gpu(const array& src, array& out, CopyType ctype);
```
```diff
@@ -80,6 +80,74 @@ void Depends::eval_gpu(
   eval(inputs, outputs);
 }

+void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
+  MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+
+  auto& in = inputs[0];
+  auto& start = inputs[1];
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto s = stream();
+  auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
+  copy_gpu_inplace(
+      /* const array& src = */ in,
+      /* array& dst = */ out,
+      /* const Shape& data_shape = */ out.shape(),
+      /* const Strides& i_strides = */ in.strides(),
+      /* const Strides& o_strides = */ out.strides(),
+      /* int64_t i_offset = */ 0,
+      /* int64_t o_offset = */ 0,
+      /* CopyType ctype = */ CopyType::GeneralGeneral,
+      /* const Stream& s = */ s,
+      /* std::optional<array> dynamic_i_offset = */ std::move(in_offset),
+      /* std::optional<array> dynamic_o_offset = */ std::nullopt);
+}
+
+void DynamicSliceUpdate::eval_gpu(
+    const std::vector<array>& inputs,
+    array& out) {
+  MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+
+  auto& in = inputs[0];
+  auto& upd = inputs[1];
+  auto& start_indices = inputs[2];
+
+  if (upd.size() == 0) {
+    out.copy_shared_buffer(in);
+    return;
+  }
+
+  // Copy or donate input to output
+  auto s = stream();
+  auto ctype = in.flags().contiguous && in.size() == in.data_size()
+      ? CopyType::Vector
+      : CopyType::General;
+  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
+
+  auto out_offset =
+      compute_dynamic_offset(start_indices, out.strides(), axes_, s);
+  copy_gpu_inplace(
+      /* const array& src = */ upd,
+      /* array& dst = */ out,
+      /* const Shape& data_shape = */ upd.shape(),
+      /* const Strides& i_strides = */ upd.strides(),
+      /* const Strides& o_strides = */ out.strides(),
+      /* int64_t i_offset = */ 0,
+      /* int64_t o_offset = */ 0,
+      /* CopyType ctype = */ CopyType::GeneralGeneral,
+      /* const Stream& s = */ s,
+      /* std::optional<array> dynamic_i_offset = */ std::nullopt,
+      /* std::optional<array> dynamic_o_offset = */ std::move(out_offset));
+}
+
 void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
   MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
   eval(inputs, out);
```
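These two `eval_gpu` implementations back the dynamic slicing ops in the Python API. A hedged usage sketch (the `mx.slice` / `mx.slice_update` parameter names are assumed from the public API, not taken from this diff):

```python
import mlx.core as mx

a = mx.arange(20).reshape(4, 5)
start = mx.array([1, 2])  # runtime start position along axes (0, 1)

# Extract a 2x3 window starting at `start`; the element offset is
# computed on the GPU by compute_dynamic_offset rather than on the host.
window = mx.slice(a, start, axes=[0, 1], slice_size=[2, 3])

# Write zeros back into the same window of `a`.
updated = mx.slice_update(a, mx.zeros((2, 3), dtype=a.dtype), start, axes=[0, 1])
print(window, updated)
```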
```diff
@@ -27,4 +27,10 @@ void pad_gpu(
     const Shape& low_pad_size,
     const Stream& s);

+array compute_dynamic_offset(
+    const array& indices,
+    const Strides& strides,
+    const std::vector<int>& axes,
+    const Stream& s);
+
 } // namespace mlx::core
```
```diff
@@ -20,8 +20,8 @@ void copy_gpu_inplace(
     int64_t out_offset,
     CopyType ctype,
     const Stream& s,
-    const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
-    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
+    std::optional<array> dynamic_i_offset /* = std::nullopt */,
+    std::optional<array> dynamic_o_offset /* = std::nullopt */) {
   if (out.size() == 0) {
     return;
   }
```
```diff
@@ -4,7 +4,6 @@
 #include <numeric>
 #include <sstream>

 #include "mlx/backend/common/compiled.h"
-#include "mlx/backend/common/slicing.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/gpu/copy.h"
```
```diff
@@ -25,60 +24,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
   enc.set_bytes(step, 1);
 }

-static array compute_dynamic_offset(
-    const array& indices,
-    const Strides& strides,
-    const std::vector<int>& axes,
-    Stream s) {
-  auto& d = metal::device(s.device);
-
-  // Kernel to compute offset here.
-  array offset({1}, int64, nullptr, {});
-  bool donate = indices.is_donatable() &&
-      (indices.data_size() * indices.itemsize()) >= offset.itemsize();
-  if (donate) {
-    offset.copy_shared_buffer(indices);
-  } else {
-    offset.set_data(allocator::malloc(offset.itemsize()));
-  }
-  d.add_temporary(offset, s.index);
-
-  auto dtype = indices.dtype();
-  std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
-  auto lib = d.get_library(lib_name, [dtype]() {
-    return fmt::format(
-        R"(
-  [[kernel]] void compute_dynamic_offset_{0}(
-      constant const {1}* indices [[buffer(0)]],
-      device int64_t& offset [[buffer(1)]],
-      constant const int64_t* strides [[buffer(2)]],
-      constant const int* axes [[buffer(3)]],
-      constant const int& n_axes [[buffer(4)]],
-      uint index [[thread_position_in_grid]]) {{
-    int64_t acc = 0;
-    for (int i = 0; i < n_axes; ++i) {{
-      acc += indices[i] * strides[axes[i]];
-    }}
-    offset = acc;
-  }})",
-        type_to_name(dtype),
-        get_type_string(dtype));
-  });
-  auto kernel = d.get_kernel(lib_name, lib);
-
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder.set_compute_pipeline_state(kernel);
-  compute_encoder.set_input_array(indices, 0);
-  compute_encoder.set_output_array(offset, 1);
-  compute_encoder.set_vector_bytes(strides, 2);
-  compute_encoder.set_vector_bytes(axes, 3);
-  int n_axes = axes.size();
-  compute_encoder.set_bytes(n_axes, 4);
-  MTL::Size dims = MTL::Size(1, 1, 1);
-  compute_encoder.dispatch_threads(dims, dims);
-  return offset;
-}
-
 void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
   out.set_data(allocator::malloc(out.nbytes()));
```
```diff
@@ -256,72 +201,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

-void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-
-  auto& in = inputs[0];
-  auto& start = inputs[1];
-  out.set_data(allocator::malloc(out.nbytes()));
-  auto s = stream();
-  auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
-  copy_gpu_inplace(
-      /* const array& src = */ in,
-      /* array& dst = */ out,
-      /* const Shape& data_shape = */ out.shape(),
-      /* const Strides& i_strides = */ in.strides(),
-      /* const Strides& o_strides = */ out.strides(),
-      /* int64_t i_offset = */ 0,
-      /* int64_t o_offset = */ 0,
-      /* CopyType ctype = */ CopyType::GeneralGeneral,
-      /* const Stream& s = */ s,
-      /* const std::optional<array>& dynamic_i_offset = */ in_offset,
-      /* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
-}
-
-void DynamicSliceUpdate::eval_gpu(
-    const std::vector<array>& inputs,
-    array& out) {
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-
-  auto& in = inputs[0];
-  auto& upd = inputs[1];
-  auto& start_indices = inputs[2];
-
-  if (upd.size() == 0) {
-    out.copy_shared_buffer(in);
-    return;
-  }
-
-  // Copy or donate input to output
-  auto s = stream();
-  auto& d = metal::device(s.device);
-  auto ctype = in.flags().contiguous && in.size() == in.data_size()
-      ? CopyType::Vector
-      : CopyType::General;
-  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
-
-  auto out_offset =
-      compute_dynamic_offset(start_indices, out.strides(), axes_, s);
-  copy_gpu_inplace(
-      /* const array& src = */ upd,
-      /* array& dst = */ out,
-      /* const Shape& data_shape = */ upd.shape(),
-      /* const Strides& i_strides = */ upd.strides(),
-      /* const Strides& o_strides = */ out.strides(),
-      /* int64_t i_offset = */ 0,
-      /* int64_t o_offset = */ 0,
-      /* CopyType ctype = */ CopyType::GeneralGeneral,
-      /* const Stream& s = */ s,
-      /* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
-      /* const std::optional<array>& dynamic_o_offset = */ out_offset);
-}
-
 void QRF::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
```
```diff
@@ -2,9 +2,12 @@

 #include <numeric>

+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/utils.h"

 namespace mlx::core {
@@ -39,4 +42,58 @@ void concatenate_gpu(
   }
 }

+array compute_dynamic_offset(
+    const array& indices,
+    const Strides& strides,
+    const std::vector<int>& axes,
+    const Stream& s) {
+  auto& d = metal::device(s.device);
+
+  // Kernel to compute offset here.
+  array offset({1}, int64, nullptr, {});
+  bool donate = indices.is_donatable() &&
+      (indices.data_size() * indices.itemsize()) >= offset.itemsize();
+  if (donate) {
+    offset.copy_shared_buffer(indices);
+  } else {
+    offset.set_data(allocator::malloc(offset.itemsize()));
+  }
+  d.add_temporary(offset, s.index);
+
+  auto dtype = indices.dtype();
+  std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
+  auto lib = d.get_library(lib_name, [dtype]() {
+    return fmt::format(
+        R"(
+  [[kernel]] void compute_dynamic_offset_{0}(
+      constant const {1}* indices [[buffer(0)]],
+      device int64_t& offset [[buffer(1)]],
+      constant const int64_t* strides [[buffer(2)]],
+      constant const int* axes [[buffer(3)]],
+      constant const int& n_axes [[buffer(4)]],
+      uint index [[thread_position_in_grid]]) {{
+    int64_t acc = 0;
+    for (int i = 0; i < n_axes; ++i) {{
+      acc += indices[i] * strides[axes[i]];
+    }}
+    offset = acc;
+  }})",
+        type_to_name(dtype),
+        get_type_string(dtype));
+  });
+  auto kernel = d.get_kernel(lib_name, lib);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(indices, 0);
+  compute_encoder.set_output_array(offset, 1);
+  compute_encoder.set_vector_bytes(strides, 2);
+  compute_encoder.set_vector_bytes(axes, 3);
+  int n_axes = axes.size();
+  compute_encoder.set_bytes(n_axes, 4);
+  MTL::Size dims = MTL::Size(1, 1, 1);
+  compute_encoder.dispatch_threads(dims, dims);
+  return offset;
+}
+
 } // namespace mlx::core
```
```diff
@@ -1,8 +1,20 @@
 if(MLX_BUILD_CUDA)
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
-  find_package(NCCL REQUIRED)
-  target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
-  target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
+  find_package(NCCL)
+  if(NCCL_FOUND)
+    target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
+    target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
+  else()
+    message(
+      STATUS
+        "NCCL not found, using stubs. To run distributed with NCCL backend, install NCCL."
+    )
+    file(
+      DOWNLOAD
+      "https://raw.githubusercontent.com/NVIDIA/nccl/refs/tags/v2.27.5-1/src/nccl.h.in"
+      "${CMAKE_CURRENT_BINARY_DIR}/nccl.h")
+    target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}")
+  endif()
 else()
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
 endif()
```
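With the stub header in place the library still builds without NCCL; actually running the NCCL backend requires a real installation. A hedged sketch of selecting it at runtime (the `backend="nccl"` name is an assumption here, mirroring how the other distributed backends are selected by name):

```python
import mlx.core as mx

# Assumed backend name; expect a runtime error if only the stub header was
# compiled in and NCCL itself is not installed on the machine.
group = mx.distributed.init(backend="nccl")
x = mx.distributed.all_sum(mx.ones(4), group=group)
print(group.rank(), x)
```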
```diff
@@ -76,7 +76,7 @@ def average_gradients(
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
-    stream: mx.Stream = mx.cpu,
+    communication_stream: Optional[mx.Stream] = None,
 ):
     """Average the gradients across the distributed processes in the passed group.
@@ -95,7 +95,9 @@
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
             type before performing the communication. Typically cast to a
             smaller float to reduce the communication size. Default: ``None``.
-        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
+        communication_stream (Optional[mlx.core.Stream]): The stream to use
+            for the communication. If unspecified the default communication
+            stream is used which can vary by back-end. Default: ``None``.
     """
     group = group or mx.distributed.init()
     N = group.size()
@@ -106,7 +108,7 @@
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
+        return mx.distributed.all_sum(x, stream=communication_stream).astype(dt) / N

     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
```
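A hedged usage sketch of the renamed keyword (gradient shapes are illustrative, and the `mlx.nn.average_gradients` import path is assumed):

```python
import mlx.core as mx
from mlx.nn import average_gradients

group = mx.distributed.init()
grads = [mx.ones(10) for _ in range(10)]  # stand-in for per-parameter grads

# The reduction stream keyword is now `communication_stream`; leaving it as
# None lets the back-end pick its default communication stream.
avg = average_gradients(
    grads,
    group=group,
    communication_type=mx.float16,
    communication_stream=None,  # was: stream=mx.cpu
)
mx.eval(avg)
```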
```diff
@@ -23,6 +23,14 @@ using namespace nb::literals;
 // Helpers
 ///////////////////////////////////////////////////////////////////////////////

+bool is_str_or_path(nb::object obj) {
+  if (nb::isinstance<nb::str>(obj)) {
+    return true;
+  }
+  nb::object path_type = nb::module_::import_("pathlib").attr("Path");
+  return nb::isinstance(obj, path_type);
+}
+
 bool is_istream_object(const nb::object& file) {
   return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
       nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
```
```diff
@@ -172,8 +180,9 @@ std::pair<
     std::unordered_map<std::string, mx::array>,
     std::unordered_map<std::string, std::string>>
 mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
-  if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
-    return mx::load_safetensors(nb::cast<std::string>(file), s);
+  if (is_str_or_path(file)) { // Assume .safetensors file path string
+    auto file_str = nb::cast<std::string>(nb::str(file));
+    return mx::load_safetensors(file_str, s);
   } else if (is_istream_object(file)) {
     // If we don't own the stream and it was passed to us, eval immediately
     auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
@@ -191,8 +200,9 @@ mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
 }

 mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
-  if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
-    return mx::load_gguf(nb::cast<std::string>(file), s);
+  if (is_str_or_path(file)) { // Assume .gguf file path string
+    auto file_str = nb::cast<std::string>(nb::str(file));
+    return mx::load_gguf(file_str, s);
   }

   throw std::invalid_argument("[load_gguf] Input must be a string");
```
```diff
@@ -201,7 +211,7 @@ mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
 std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
     nb::object file,
     mx::StreamOrDevice s) {
-  bool own_file = nb::isinstance<nb::str>(file);
+  bool own_file = is_str_or_path(file);

   nb::module_ zipfile = nb::module_::import_("zipfile");
   if (!is_zip_file(zipfile, file)) {
@@ -242,8 +252,9 @@ std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
 }

 mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
-  if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
-    return mx::load(nb::cast<std::string>(file), s);
+  if (is_str_or_path(file)) { // Assume .npy file path string
+    auto file_str = nb::cast<std::string>(nb::str(file));
+    return mx::load(file_str, s);
   } else if (is_istream_object(file)) {
     // If we don't own the stream and it was passed to us, eval immediately
     auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
```
```diff
@@ -264,8 +275,8 @@ LoadOutputTypes mlx_load_helper(
     mx::StreamOrDevice s) {
   if (!format.has_value()) {
     std::string fname;
-    if (nb::isinstance<nb::str>(file)) {
-      fname = nb::cast<std::string>(file);
+    if (is_str_or_path(file)) {
+      fname = nb::cast<std::string>(nb::str(file));
     } else if (is_istream_object(file)) {
       fname = nb::cast<std::string>(file.attr("name"));
     } else {
```
```diff
@@ -384,8 +395,9 @@ class PyFileWriter : public mx::io::Writer {
 };

 void mlx_save_helper(nb::object file, mx::array a) {
-  if (nb::isinstance<nb::str>(file)) {
-    mx::save(nb::cast<std::string>(file), a);
+  if (is_str_or_path(file)) {
+    auto file_str = nb::cast<std::string>(nb::str(file));
+    mx::save(file_str, a);
     return;
   } else if (is_ostream_object(file)) {
     auto writer = std::make_shared<PyFileWriter>(file);
```
```diff
@@ -409,8 +421,8 @@ void mlx_savez_helper(
   // Add .npz to the end of the filename if not already there
   nb::object file = file_;

-  if (nb::isinstance<nb::str>(file_)) {
-    std::string fname = nb::cast<std::string>(file_);
+  if (is_str_or_path(file)) {
+    std::string fname = nb::cast<std::string>(nb::str(file_));

     // Add .npz to file name if it is not there
     if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
```
```diff
@@ -473,11 +485,11 @@ void mlx_save_safetensor_helper(
     metadata_map = std::unordered_map<std::string, std::string>();
   }
   auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
-  if (nb::isinstance<nb::str>(file)) {
+  if (is_str_or_path(file)) {
     {
+      auto file_str = nb::cast<std::string>(nb::str(file));
       nb::gil_scoped_release nogil;
-      mx::save_safetensors(
-          nb::cast<std::string>(file), arrays_map, metadata_map);
+      mx::save_safetensors(file_str, arrays_map, metadata_map);
     }
   } else if (is_ostream_object(file)) {
     auto writer = std::make_shared<PyFileWriter>(file);
```
```diff
@@ -496,19 +508,21 @@ void mlx_save_gguf_helper(
     nb::dict a,
     std::optional<nb::dict> m) {
   auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
-  if (nb::isinstance<nb::str>(file)) {
+  if (is_str_or_path(file)) {
     if (m) {
       auto metadata_map =
           nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
               m.value());
       {
+        auto file_str = nb::cast<std::string>(nb::str(file));
         nb::gil_scoped_release nogil;
-        mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
+        mx::save_gguf(file_str, arrays_map, metadata_map);
       }
     } else {
       {
+        auto file_str = nb::cast<std::string>(nb::str(file));
         nb::gil_scoped_release nogil;
-        mx::save_gguf(nb::cast<std::string>(file), arrays_map);
+        mx::save_gguf(file_str, arrays_map);
       }
     }
   } else {
```
```diff
@@ -3911,12 +3911,13 @@ void init_ops(nb::module_& m) {
       &mlx_save_helper,
       "file"_a,
       "arr"_a,
-      nb::sig("def save(file: str, arr: array) -> None"),
+      nb::sig(
+          "def save(file: Union[file, str, pathlib.Path], arr: array) -> None"),
       R"pbdoc(
         Save the array to a binary file in ``.npy`` format.

         Args:
-            file (str): File to which the array is saved
+            file (str, pathlib.Path, file): File to which the array is saved
             arr (array): Array to be saved.
       )pbdoc");
   m.def(
@@ -3927,6 +3928,8 @@ void init_ops(nb::module_& m) {
       "file"_a,
       "args"_a,
       "kwargs"_a,
+      nb::sig(
+          "def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
       R"pbdoc(
         Save several arrays to a binary file in uncompressed ``.npz``
         format.
@@ -3946,7 +3949,7 @@ void init_ops(nb::module_& m) {
            mx.savez("model.npz", **dict(flat_params))

         Args:
-            file (file, str): Path to file to which the arrays are saved.
+            file (file, str, pathlib.Path): Path to file to which the arrays are saved.
             *args (arrays): Arrays to be saved.
             **kwargs (arrays): Arrays to be saved. Each array will be saved
               with the associated keyword as the output file name.
@@ -3959,12 +3962,13 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       "args"_a,
       "kwargs"_a,
-      nb::sig("def savez_compressed(file: str, *args, **kwargs)"),
+      nb::sig(
+          "def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"),
       R"pbdoc(
         Save several arrays to a binary file in compressed ``.npz`` format.

         Args:
-            file (file, str): Path to file to which the arrays are saved.
+            file (file, str, pathlib.Path): Path to file to which the arrays are saved.
             *args (arrays): Arrays to be saved.
             **kwargs (arrays): Arrays to be saved. Each array will be saved
               with the associated keyword as the output file name.
```
```diff
@@ -3978,7 +3982,7 @@ void init_ops(nb::module_& m) {
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
+          "def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
       R"pbdoc(
         Load array(s) from a binary file.
@@ -3986,7 +3990,7 @@ void init_ops(nb::module_& m) {
         ``.gguf``.

         Args:
-            file (file, str): File in which the array is saved.
+            file (file, str, pathlib.Path): File in which the array is saved.
             format (str, optional): Format of the file. If ``None``, the
               format is inferred from the file extension. Supported formats:
               ``npy``, ``npz``, and ``safetensors``. Default: ``None``.
@@ -4012,7 +4016,7 @@ void init_ops(nb::module_& m) {
       "arrays"_a,
       "metadata"_a = nb::none(),
       nb::sig(
-          "def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
+          "def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
       R"pbdoc(
         Save array(s) to a binary file in ``.safetensors`` format.
@@ -4021,7 +4025,7 @@ void init_ops(nb::module_& m) {
         information on the format.

         Args:
-            file (file, str): File in which the array is saved.
+            file (file, str, pathlib.Path): File in which the array is saved.
             arrays (dict(str, array)): The dictionary of names to arrays to
               be saved.
             metadata (dict(str, str), optional): The dictionary of
@@ -4034,7 +4038,7 @@ void init_ops(nb::module_& m) {
       "arrays"_a,
       "metadata"_a = nb::none(),
       nb::sig(
-          "def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
+          "def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
       R"pbdoc(
         Save array(s) to a binary file in ``.gguf`` format.
@@ -4043,7 +4047,7 @@ void init_ops(nb::module_& m) {
         more information on the format.

         Args:
-            file (file, str): File in which the array is saved.
+            file (file, str, pathlib.Path): File in which the array is saved.
             arrays (dict(str, array)): The dictionary of names to arrays to
               be saved.
             metadata (dict(str, Union[array, str, list(str)])): The dictionary
```
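Taken together with the `is_str_or_path` helper above, the save and load entry points now accept `pathlib.Path` objects as well as strings and open files. A brief hedged sketch:

```python
from pathlib import Path
import mlx.core as mx

p = Path("weights") / "a.npy"
p.parent.mkdir(exist_ok=True)

mx.save(p, mx.arange(8))  # previously required str(p)
a = mx.load(p)
mx.save_safetensors(Path("weights/w.safetensors"), {"a": a.astype(mx.float32)})
```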
```diff
@@ -1,7 +1,6 @@
 cuda_skip = {
     "TestLoad.test_load_f8_e4m3",
     "TestLayers.test_quantized_embedding",
-    "TestOps.test_dynamic_slicing",
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",
     # Gather matmul NYI
```
```diff
@@ -65,21 +65,21 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
         mx.distributed.all_sum = new_all_sum
         try:
             grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads, stream=mx.gpu)
+            new_grads = average_gradients(grads)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 1)

             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
             self.assertEqual(n_calls, 2)

             n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
+            new_grads = average_gradients(grads, all_reduce_size=0)
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
             self.assertTrue(all(mx.all(g == 1) for g in new_grads))
@@ -91,7 +91,6 @@ class TestNCCLDistributed(mlx_tests.MLXTestCase):
                 grads,
                 all_reduce_size=2 * 50,
                 communication_type=mx.float16,
-                stream=mx.gpu,
             )
             mx.eval(new_grads)
             self.assertEqual(len(new_grads), 10)
```