mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
a393435d28
...
343e33b6d5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
343e33b6d5 | ||
|
|
0073096dd1 | ||
|
|
e3d004fed9 |
@@ -99,6 +99,30 @@ const std::filesystem::path& ptx_cache_dir() {
|
||||
return cache;
|
||||
}
|
||||
|
||||
std::filesystem::path get_ptx_path(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name) {
|
||||
#ifdef _WIN32
|
||||
constexpr int max_file_name_length = 140;
|
||||
#else
|
||||
constexpr int max_file_name_length = 245;
|
||||
#endif
|
||||
|
||||
if (module_name.size() <= max_file_name_length) {
|
||||
return cache_dir / (module_name + ".ptx");
|
||||
}
|
||||
|
||||
auto ptx_path = cache_dir;
|
||||
int offset = 0;
|
||||
while (module_name.size() - offset > max_file_name_length) {
|
||||
ptx_path /= module_name.substr(offset, max_file_name_length);
|
||||
offset += max_file_name_length;
|
||||
}
|
||||
ptx_path /= module_name.substr(offset) + ".ptx";
|
||||
|
||||
return ptx_path;
|
||||
}
|
||||
|
||||
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||
bool read_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
@@ -109,7 +133,7 @@ bool read_cached_ptx(
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||
auto ptx_path = get_ptx_path(cache_dir, module_name);
|
||||
std::error_code error;
|
||||
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||
if (error) {
|
||||
@@ -122,7 +146,7 @@ bool read_cached_ptx(
|
||||
ptx.resize(ptx_size);
|
||||
ptx_file.read(ptx.data(), ptx_size);
|
||||
|
||||
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
|
||||
std::string line;
|
||||
while (std::getline(txt_file, line)) {
|
||||
auto tab = line.find('\t');
|
||||
@@ -144,16 +168,26 @@ void write_cached_ptx(
|
||||
return;
|
||||
}
|
||||
|
||||
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||
auto ptx_path = get_ptx_path(cache_dir, module_name);
|
||||
|
||||
// Ensure that the directory exists
|
||||
auto parent = ptx_path.parent_path();
|
||||
if (parent != cache_dir) {
|
||||
std::filesystem::create_directories(parent);
|
||||
}
|
||||
|
||||
// Write the compiled code and mangled names
|
||||
std::ofstream ptx_file(ptx_path, std::ios::binary);
|
||||
if (!ptx.empty()) {
|
||||
ptx_file.write(&ptx.front(), ptx.size());
|
||||
}
|
||||
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
txt_file << name << "\t" << mangled << std::endl;
|
||||
}
|
||||
|
||||
std::ofstream source_file(cache_dir / (module_name + ".cu"));
|
||||
// Write the generated code
|
||||
std::ofstream source_file(ptx_path.replace_extension(".cu"));
|
||||
source_file << source_code;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -83,7 +81,8 @@ struct RowReduceArgs {
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
||||
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
__global__ void
|
||||
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||
ReduceOp op;
|
||||
|
||||
T vals[M][N];
|
||||
U accs[M];
|
||||
AlignedVector<T, N> vals[M];
|
||||
AlignedVector<U, M> accs;
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = init;
|
||||
}
|
||||
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
||||
const size_t full_blocks = size / (block.size() * N);
|
||||
const size_t final_offset = full_blocks * (block.size() * N);
|
||||
in += start_row * size;
|
||||
in += start_row * size + block.thread_rank() * N;
|
||||
out += start_row;
|
||||
|
||||
if (size % N == 0) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
vals[k] = load_vector<N>(in + k * size, 0);
|
||||
}
|
||||
for (int k = 0; k < M; k++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
|
||||
in += block.size() * N;
|
||||
}
|
||||
|
||||
if (final_offset < size) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + final_offset,
|
||||
vals[k],
|
||||
size,
|
||||
cast_to<T>(init));
|
||||
for (int i = 0; i < N; i++) {
|
||||
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
|
||||
? in[k * size + i]
|
||||
: cast_to<T>(init);
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < M; k++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
}
|
||||
|
||||
__shared__ U shared_accumulators[32 * M];
|
||||
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
||||
block_reduce(block, warp, accs.val, shared_accumulators, op, init);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
if (grid.block_rank() * M + M <= n_rows) {
|
||||
for (int i = 0; i < M; i++) {
|
||||
out[i] = accs[i];
|
||||
}
|
||||
store_vector(out, 0, accs);
|
||||
} else {
|
||||
short offset = grid.block_rank() * M + M - n_rows;
|
||||
for (int i = offset; i < M; i++) {
|
||||
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM,
|
||||
int N_READS = 4>
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
T* in,
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
@@ -185,36 +163,60 @@ __global__ void row_reduce_looped(
|
||||
U init = ReduceInit<Op, T>::value();
|
||||
total[0] = init;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
|
||||
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
|
||||
const size_t full_blocks = args.row_size / (block.size() * N_READS);
|
||||
const size_t final_offset = full_blocks * (block.size() * N_READS);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
in += block.thread_rank() * N_READS;
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + r * BLOCK_DIM * N_READS,
|
||||
vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
// Unaligned reduce
|
||||
if (final_offset < args.row_size) {
|
||||
bool mask[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
mask[i] =
|
||||
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
|
||||
}
|
||||
if (final_offset < args.row_size) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + final_offset,
|
||||
vals,
|
||||
args.row_size - final_offset,
|
||||
cast_to<T>(init));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
}
|
||||
|
||||
{
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
}
|
||||
|
||||
// Aligned case
|
||||
else {
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
}
|
||||
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
// TODO: Maybe block.sync() here?
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
__shared__ U shared_accumulators[32];
|
||||
@@ -234,8 +236,6 @@ void row_reduce_simple(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
||||
// kernel.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -250,14 +250,15 @@ void row_reduce_simple(
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
constexpr int N_READS = 16 / sizeof(T);
|
||||
|
||||
// Calculate the grid and block dims
|
||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
warps /= 4;
|
||||
warps = std::max(std::min(warps, 32), 1);
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
@@ -267,6 +268,7 @@ void row_reduce_simple(
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
int size = plan.shape.back();
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
||||
@@ -282,8 +284,6 @@ void row_reduce_looped(
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::RowReduceArgs args) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -295,34 +295,27 @@ void row_reduce_looped(
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
constexpr int N_READS = 16 / sizeof(T);
|
||||
|
||||
// Calculate the grid and block dims
|
||||
args.sort_access_pattern(in, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
warps /= 4;
|
||||
warps = std::max(std::min(warps, 32), 1);
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
dispatch_block_dim(threads, [&](auto threads_constant) {
|
||||
kernel = cu::row_reduce_looped<
|
||||
T,
|
||||
U,
|
||||
OP,
|
||||
reduce_ndim.value,
|
||||
threads_constant.value,
|
||||
N_READS>;
|
||||
block.x = threads_constant.value;
|
||||
});
|
||||
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
|
||||
});
|
||||
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
|
||||
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
|
||||
std::vector<array> AllReduce::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>&) {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {all_sum(tangents[0], group(), stream())};
|
||||
@@ -46,7 +46,7 @@ std::vector<array> AllReduce::jvp(
|
||||
std::vector<array> AllReduce::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>& outputs) {
|
||||
return cotangents;
|
||||
}
|
||||
@@ -60,21 +60,30 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
|
||||
std::vector<array> AllGather::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>&) {
|
||||
return {all_gather(tangents[0], group(), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> AllGather::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
auto g = group();
|
||||
auto ndim = primals[0].ndim();
|
||||
Shape starts(primals[0].ndim(), 0);
|
||||
auto stops = primals[0].shape();
|
||||
if (ndim == 0) {
|
||||
starts.push_back(0);
|
||||
stops.push_back(1);
|
||||
}
|
||||
starts[0] = g.rank() * stops[0];
|
||||
stops[0] += starts[0];
|
||||
return {slice(cotangents[0], starts, stops)};
|
||||
auto out = slice(cotangents[0], starts, stops);
|
||||
if (ndim == 0) {
|
||||
out = squeeze(out, 0);
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
|
||||
|
||||
@@ -129,6 +129,16 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
self.assertTrue(mx.all(y == x[world.rank()]))
|
||||
self.assertTrue(mx.all(z == x[left]))
|
||||
|
||||
def test_all_gather_vjp(self):
|
||||
def fun(x):
|
||||
return mx.distributed.all_gather(x)[0]
|
||||
|
||||
dfdx = mx.grad(fun)(mx.array(1.0))
|
||||
if mx.distributed.init().rank() == 0:
|
||||
self.assertEqual(dfdx.item(), 1.0)
|
||||
else:
|
||||
self.assertEqual(dfdx.item(), 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user