mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-11 22:44:38 +08:00

Compare commits (13 commits)

Commits (SHA1):
eac961ddb1
57c6aa7188
cde5b4ad80
4f72c66911
960e3f0f05
884af42da2
048fabdabd
917252a5a1
1a992e31e8
d2ff04a4f2
015c247393
d3cd26820e
91f6c499d7
@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)

 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.19.1)
+  set(MLX_VERSION 0.19.3)
 endif()

 # --------------------- Processor tests -------------------------
@@ -6,7 +6,7 @@

 [](https://circleci.com/gh/ml-explore/mlx)

-MLX is an array framework for machine learning research on Apple silicon,
+MLX is an array framework for machine learning on Apple silicon,
 brought to you by Apple machine learning research.

 Some key features of MLX include:
@@ -9,7 +9,7 @@ from time_utils import measure_runtime

 def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[*idx] = x
+        dst[tuple(idx)] = x
         mx.eval(dst)

     idx = []

@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):


 def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
-    def gather(dst, x, idx, device):
-        dst[*idx] = x
+    def scatter(dst, x, idx, device):
+        dst[tuple(idx)] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()

@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

-    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
     print(f"PyTorch: {runtime:.3f}ms")


@@ -54,7 +54,7 @@ if __name__ == "__main__":
         (100_000, 64),
         (1_000_000, 64),
         (100_000,),
-        (2_000_00,),
+        (200_000,),
         (20_000_000,),
         (10000, 64),
         (100, 64),

@@ -91,6 +91,6 @@ if __name__ == "__main__":

     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)
-        print(f"X {x_shape}, Indices {idx_shape}")
+        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
         benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
         benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
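A note on the indexing change above: star-unpacking inside a subscript (`dst[*idx] = x`) is only valid syntax on Python 3.11 and newer, while `dst[tuple(idx)] = x` expresses the same multi-axis assignment on every supported version. A minimal sketch, using NumPy purely for illustration (the benchmark itself operates on mx and torch arrays):

```python
import numpy as np

dst = np.zeros((4, 4))
x = np.ones(2)
idx = [np.array([0, 1]), np.array([2, 3])]

# dst[*idx] = x      # SyntaxError on Python < 3.11
dst[tuple(idx)] = x  # same multi-axis scatter, valid on all versions
print(dst[0, 2], dst[1, 3])  # 1.0 1.0
```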
@@ -60,6 +60,7 @@ html_theme_options = {
     },
 }

+html_favicon = html_theme_options["logo"]["image_light"]

 # -- Options for HTMLHelp output ---------------------------------------------

@@ -271,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() {
     for (array& a : ad.inputs) {
       if (a.array_desc_) {
         input_map.insert({a.id(), a});
+        for (auto& s : a.siblings()) {
+          input_map.insert({s.id(), s});
+        }
       }
     }
     ad.inputs.clear();
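For context, the destructor this hunk touches collects each node's inputs (and, after this change, their siblings) into a map before clearing the input lists, which appears to keep graph teardown iterative instead of deeply recursive. A minimal Python sketch of that general pattern, with a hypothetical Node type rather than MLX's actual classes:

```python
# Hypothetical Node with an `inputs` list; illustrates the teardown pattern only.
class Node:
    def __init__(self, inputs=()):
        self.inputs = list(inputs)


def release(root):
    pending = [root]
    while pending:
        node = pending.pop()
        # Keep children reachable from `pending`, then sever the links so that
        # dropping `node` never triggers a long recursive chain of destructors.
        pending.extend(node.inputs)
        node.inputs.clear()


# A chain this long would be risky to free through naive recursion.
chain = Node()
for _ in range(100_000):
    chain = Node([chain])
release(chain)
```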
@@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
 make_jit_source(binary_ops)
 make_jit_source(ternary_ops)
 make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
-make_jit_source(scatter)
-make_jit_source(gather)
+make_jit_source(scatter kernels/indexing.h)
+make_jit_source(gather kernels/indexing.h)
 make_jit_source(hadamard)

 if(MLX_METAL_JIT)
@@ -1,5 +1,4 @@
 // Copyright © 2024 Apple Inc.

-#include "mlx/backend/common/binary.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"

@@ -110,6 +109,7 @@ void binary_op_gpu_inplace(
     compute_encoder.set_output_array(outputs[1], arg_idx++);
   }

+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (bopt == BinaryOpType::General) {
     // Launch up to 3D grid of threads
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;

@@ -132,7 +132,6 @@ void binary_op_gpu_inplace(
           strides_b.data(), ndim * sizeof(size_t), arg_idx++);
     }

-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
       throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
     }

@@ -142,13 +141,12 @@ void binary_op_gpu_inplace(
   } else {
     // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                                 : MTL::Size(nthreads, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
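The binary kernel changes above hoist `thread_group_size` to a single query and, on the 1D/2D path, clamp the threadgroup size to the total number of threads before dispatching. A rough sketch of that sizing rule in Python, with illustrative names rather than the Metal API:

```python
def plan_1d_dispatch(nthreads, max_threads_per_group=1024):
    """Sketch of the sizing rule above: clamp the threadgroup width to the
    total thread count so tiny arrays do not launch a mostly idle group."""
    group = min(max_threads_per_group, nthreads)
    # dispatchThreads takes the total grid size and the threadgroup size.
    return (nthreads, 1, 1), (group, 1, 1)


print(plan_1d_dispatch(37))       # ((37, 1, 1), (37, 1, 1))
print(plan_1d_dispatch(1 << 20))  # ((1048576, 1, 1), (1024, 1, 1))
```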
@@ -421,11 +421,12 @@ void Compiled::eval_gpu(
   // Launch the kernel
   if (contiguous) {
     size_t nthreads = outputs[0].data_size();
+    MTL::Size group_dims(
+        std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
+
     MTL::Size grid_dims = use_2d
         ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
         : MTL::Size(nthreads, 1, 1);
-    MTL::Size group_dims(
-        std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -752,10 +752,6 @@ void conv_2D_gpu(
   bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
   bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;

-  bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
-  bool channels_large = (conv_params.C + conv_params.O) >= 512;
-  bool channels_med = (conv_params.C + conv_params.O) >= 256;
-
   if (groups > 1) {
     const int C_per_group = conv_params.C / groups;
     const int O_per_group = conv_params.O / groups;

@@ -769,10 +765,13 @@ void conv_2D_gpu(
   }

   // Direct to winograd conv
+  bool inp_large =
+      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
+  bool channels_large = (conv_params.C + conv_params.O) >= 256;
   if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
       conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
-      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
-      (channels_large || (channels_med && inp_large))) {
+      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
+      channels_large) {
     return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
   }

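Restated as a standalone predicate, the revised Winograd dispatch rule above is easier to scan. This is a sketch with illustrative parameter names, not the actual conv_2D_gpu signature:

```python
def use_winograd(N, iS, wS, C, O, flip, stride_one, kdil_one, idil_one):
    """Sketch of the updated eligibility check: 3x3 kernels, unit stride and
    dilation, channels divisible by 32, and both a large enough input and a
    large enough combined channel count."""
    inp_large = N * iS[0] * iS[1] >= 1 << 12
    channels_large = (C + O) >= 256
    return (
        not flip
        and stride_one and kdil_one and idil_one
        and wS == (3, 3)
        and C % 32 == 0 and O % 32 == 0
        and inp_large and channels_large
    )


# Example: a batch of four 32x32 feature maps with 128 -> 128 channels qualifies.
print(use_winograd(4, (32, 32), (3, 3), 128, 128, False, True, True, True))  # True
```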
@@ -120,6 +120,7 @@ void copy_gpu_inplace(
   compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
   compute_encoder.set_output_array(out, 1, out_offset);

+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
     std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
     std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};

@@ -145,7 +146,6 @@ void copy_gpu_inplace(
     }

     // NB assuming thread_group_size is a power of 2 larger than 32 x 32
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
       throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
     }

@@ -155,13 +155,12 @@ void copy_gpu_inplace(
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                                 : MTL::Size(nthreads, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }

@@ -205,14 +204,14 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   compute_encoder.set_input_array(val, 0);
   compute_encoder.set_output_array(out, 1);

+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   size_t nthreads = out.data_size();
+  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                               : MTL::Size(nthreads, 1, 1);
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (thread_group_size > nthreads) {
     thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                               : MTL::Size(nthreads, 1, 1);
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

@@ -181,6 +181,7 @@ Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
   library_map_ = {{"mlx", load_library(device_)}};
+  arch_ = std::string(device_->architecture()->name()->utf8String());
 }

 Device::~Device() {

@@ -136,6 +136,10 @@ class Device {
     return device_;
   };

+  const std::string& get_architecture() {
+    return arch_;
+  }
+
   void new_queue(int index);
   MTL::CommandBuffer* get_command_buffer(int index);
   int get_command_buffer_ops(int index);

@@ -228,6 +232,7 @@ class Device {
   std::shared_mutex library_mtx_;
   std::unordered_map<std::string, MTL::Library*> library_map_;
   const MTL::ResidencySet* residency_set_{nullptr};
+  std::string arch_;
 };

 Device& device(mlx::core::Device);
@@ -113,17 +113,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
|
||||
}
|
||||
|
||||
// Set all the buffers
|
||||
@@ -131,21 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
// Set source info
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
|
||||
set_vector_bytes(compute_encoder, src.shape(), 2);
|
||||
set_vector_bytes(compute_encoder, src.strides(), 3);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
|
||||
set_vector_bytes(compute_encoder, slice_sizes_, 5);
|
||||
set_vector_bytes(compute_encoder, axes_, 6);
|
||||
|
||||
// Set index info
|
||||
//
|
||||
// We don't need to check for empty idx_shapes because gather has a
|
||||
// idx_ndim == 0 specialization
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
|
||||
set_vector_bytes(compute_encoder, idx_shapes, 7);
|
||||
set_vector_bytes(compute_encoder, idx_strides, 8);
|
||||
set_vector_bytes(compute_encoder, idx_contigs, 9);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
@@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Copy src into out
|
||||
auto copy_type =
|
||||
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
CopyType copy_type;
|
||||
if (inputs[0].data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (inputs[0].flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(inputs[0], out, copy_type);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
|
||||
// Empty update
|
||||
if (inputs.back().size() == 0) {
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
bool index_nd1_specialization = (idx_ndim == 1);
|
||||
size_t idx_size = nidx ? inputs[1].size() : 1;
|
||||
|
||||
// Bail from fast path (1d index specialization) if scatter dims aren't
|
||||
// the outermost dims and contiguous since update access won't be raster
|
||||
// order.
|
||||
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
|
||||
index_nd1_specialization &= (axes_[i] == i);
|
||||
}
|
||||
|
||||
// Bail from fast path (1d index specialization) if any of the dims are
|
||||
// broadcasted, since we can't rely on linear indexing in that case.
|
||||
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
|
||||
index_nd1_specialization &= inputs[i].flags().row_contiguous;
|
||||
auto idx_to_out = idx_size / out.size();
|
||||
int nwork;
|
||||
if (idx_ndim <= 1 || idx_to_out < 1) {
|
||||
nwork = 1;
|
||||
} else if (idx_to_out <= 4) {
|
||||
nwork = 4;
|
||||
} else if (idx_to_out < 16) {
|
||||
nwork = 8;
|
||||
} else if (idx_to_out < 32) {
|
||||
nwork = 16;
|
||||
} else {
|
||||
nwork = 32;
|
||||
}
|
||||
|
||||
std::string lib_name;
|
||||
@@ -222,19 +230,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op_name = "min";
|
||||
break;
|
||||
}
|
||||
|
||||
auto upd_contig = upd.flags().row_contiguous;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
if (index_nd1_specialization) {
|
||||
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
|
||||
} else {
|
||||
kname << "scatter" << type_to_name(out) << idx_type_name;
|
||||
}
|
||||
kname << "_" << op_name << "_" << nidx;
|
||||
kname << "scatter" << type_to_name(out) << idx_type_name;
|
||||
kname << "_" << op_name << "_" << nidx << "_"
|
||||
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
|
||||
lib_name = kname.str();
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils()
|
||||
@@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op_type,
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr);
|
||||
idx_arr,
|
||||
upd_contig,
|
||||
nwork);
|
||||
return kernel_source.str();
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
size_t nthreads = upd.size();
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
@@ -291,109 +296,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set update info
|
||||
uint upd_ndim = upd.ndim();
|
||||
size_t upd_ndim = upd.ndim();
|
||||
size_t upd_size = 1;
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
if (index_nd1_specialization) {
|
||||
compute_encoder->setBytes(
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
|
||||
if (upd_ndim <= 1) {
|
||||
// Placeholder so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
} else {
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
}
|
||||
|
||||
if (upd_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
|
||||
} else {
|
||||
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out_ndim * sizeof(size_t), 8);
|
||||
}
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
// Set index info
|
||||
if (idx_ndim == 0) {
|
||||
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
|
||||
// error in the metal API.
|
||||
idx_shapes.push_back(0);
|
||||
idx_strides.push_back(0);
|
||||
}
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
// To access .data() use char instead of bool
|
||||
// bool is 1 byte in Metal so this is safe
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
|
||||
}
|
||||
|
||||
if (upd_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
|
||||
} else {
|
||||
set_vector_bytes(compute_encoder, upd.shape(), 3);
|
||||
set_vector_bytes(compute_encoder, upd.strides(), 4);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
size_t stride_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
|
||||
} else {
|
||||
set_vector_bytes(compute_encoder, out.shape(), 7);
|
||||
set_vector_bytes(compute_encoder, out.strides(), 8);
|
||||
}
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
// Set index info
|
||||
if (idx_ndim == 0) {
|
||||
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
|
||||
// error in the metal API.
|
||||
idx_shapes.push_back(0);
|
||||
idx_strides.push_back(0);
|
||||
idx_contigs.push_back(false);
|
||||
}
|
||||
set_vector_bytes(compute_encoder, idx_shapes, 11);
|
||||
set_vector_bytes(compute_encoder, idx_strides, 12);
|
||||
set_vector_bytes(compute_encoder, idx_contigs, 13);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
|
||||
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
auto grid_y = (nthreads / upd_size);
|
||||
grid_y = (grid_y + nwork - 1) / nwork;
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
|
||||
}
|
||||
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -11,12 +11,13 @@ constexpr std::string_view gather_kernels = R"(
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
const constant bool* idx_contigs [[buffer(9)]],
|
||||
const constant int& idx_ndim [[buffer(10)]],
|
||||
{4}
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
src,
|
||||
@@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter_1d_index{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates,
|
||||
out,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
upd_shape,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
idx_buffers,
|
||||
gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
@@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant int& idx_ndim [[buffer(13)]],
|
||||
const constant bool* idx_contigs [[buffer(13)]],
|
||||
const constant int& idx_ndim [[buffer(14)]],
|
||||
const constant size_t& idx_size [[buffer(15)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}>(
|
||||
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
@@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
out_strides,
|
||||
out_ndim,
|
||||
axes,
|
||||
idx_size,
|
||||
idxs,
|
||||
gid);
|
||||
}}
|
||||
|
@@ -50,7 +50,9 @@ set(STEEL_HEADERS
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h)
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h)
|
||||
|
||||
if(NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
|
@@ -25,11 +25,13 @@ METAL_FUNC void gather_impl(
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
} else {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc += elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
idx_loc += indices.row_contiguous[i]
|
||||
? index.y
|
||||
: elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
indices.ndim - 1);
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
|
@@ -9,6 +9,7 @@ struct Indices {
|
||||
const array<const device IdxT*, NIDX> buffers;
|
||||
const constant int* shapes;
|
||||
const constant size_t* strides;
|
||||
const constant bool* row_contiguous;
|
||||
const int ndim;
|
||||
};
|
||||
|
||||
|
@@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
[[kernel]] void rbitsc(
|
||||
device const uint32_t* keys,
|
||||
device char* out,
|
||||
device const bool& odd,
|
||||
device const uint& bytes_per_key,
|
||||
constant const bool& odd,
|
||||
constant const uint& bytes_per_key,
|
||||
uint2 grid_dim [[threads_per_grid]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto kidx = 2 * index.x;
|
||||
@@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
[[kernel]] void rbits(
|
||||
device const uint32_t* keys,
|
||||
device char* out,
|
||||
device const bool& odd,
|
||||
device const uint& bytes_per_key,
|
||||
constant const bool& odd,
|
||||
constant const uint& bytes_per_key,
|
||||
constant const int& ndim,
|
||||
constant const int* key_shape,
|
||||
constant const size_t* key_strides,
|
||||
|
@@ -4,73 +4,54 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
METAL_FUNC void scatter_1d_index_impl(
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; i++) {
|
||||
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
if (upd_ndim > 1) {
|
||||
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
} else {
|
||||
out_idx += gid.x;
|
||||
}
|
||||
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NIDX,
|
||||
bool UPD_ROW_CONTIG,
|
||||
int NWORK>
|
||||
METAL_FUNC void scatter_impl(
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const device T* updates,
|
||||
device mlx_atomic<T>* out,
|
||||
const constant int* upd_shape,
|
||||
const constant size_t* upd_strides,
|
||||
const constant size_t& upd_ndim,
|
||||
const constant size_t& upd_size,
|
||||
const constant int* out_shape,
|
||||
const constant size_t* out_strides,
|
||||
const constant size_t& out_ndim,
|
||||
const constant int* axes,
|
||||
const constant size_t& idx_size,
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
auto ind_idx = gid.y;
|
||||
auto ind_offset = gid.x;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
|
||||
auto ind_idx = gid.y * NWORK;
|
||||
size_t out_offset = 0;
|
||||
if (upd_size > 1) {
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
out_offset =
|
||||
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
}
|
||||
|
||||
auto upd_idx =
|
||||
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
|
||||
size_t out_idx = out_offset;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = indices.row_contiguous[i]
|
||||
? ind_idx
|
||||
: elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
auto upd_idx = ind_idx * upd_size + gid.x;
|
||||
if constexpr (!UPD_ROW_CONTIG) {
|
||||
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
|
||||
}
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
}
|
||||
}
|
||||
|
@@ -142,8 +142,8 @@ implicit_gemm_conv_2d_general(
|
||||
// Store results to device memory
|
||||
{
|
||||
// Adjust for simdgroup and thread locatio
|
||||
int offset_m = c_row + mma_op.sm + mma_op.tm;
|
||||
int offset_n = c_col + mma_op.sn + mma_op.tn;
|
||||
int offset_m = c_row + mma_op.sm;
|
||||
int offset_n = c_col + mma_op.sn;
|
||||
C += offset_n;
|
||||
|
||||
if (offset_n >= gemm_params->N)
|
||||
@@ -169,17 +169,17 @@ implicit_gemm_conv_2d_general(
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < mma_t::TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum =
|
||||
mma_op.results[i * mma_t::TN + j].thread_elements();
|
||||
thread const auto& accum = mma_op.Ctile.frag_at(i, j);
|
||||
int offset = offset_cm + (j * mma_t::TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * mma_t::TN_stride < diff) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;
|
||||
|
||||
if (j * mma_t::TN_stride + 1 < diff) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
// Apply epilogue and output C
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
if ((j * mma_t::TN_stride + k) < diff) {
|
||||
C[offset + k] = Epilogue::apply(accum[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -36,11 +36,11 @@
|
||||
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
@@ -18,6 +19,347 @@ using namespace metal;
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <typename T, int kFragRows_, int kFragCols_>
|
||||
struct BaseMMAFrag {
|
||||
static_assert(
|
||||
kFragRows_ == 8,
|
||||
"Only 8 x 8 fragment matrices are currently supported");
|
||||
static_assert(
|
||||
kFragCols_ == 8,
|
||||
"Only 8 x 8 fragment matrices are currently supported");
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BaseMMAFrag<T, 8, 8> {
|
||||
STEEL_CONST int kFragRows = 8;
|
||||
STEEL_CONST int kFragCols = 8;
|
||||
|
||||
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
|
||||
|
||||
STEEL_CONST int kElemRows = 1;
|
||||
STEEL_CONST int kElemCols = 2;
|
||||
|
||||
static_assert(
|
||||
kElemRows * kElemCols == kElemsPerFrag,
|
||||
"MMAFrag shape is not consistent with MMAFrag size");
|
||||
|
||||
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
||||
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
||||
|
||||
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
||||
[[thread_index_in_simdgroup]]) {
|
||||
const short qid = simd_lane_id / 4;
|
||||
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
|
||||
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
return short2{fn, fm};
|
||||
}
|
||||
|
||||
template <typename SrcPtrType, typename StrX, typename StrY>
|
||||
METAL_FUNC static constexpr void
|
||||
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename SrcPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename LimX,
|
||||
typename LimY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void load_safe(
|
||||
thread frag_type& dst,
|
||||
SrcPtrType src,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
LimX lim_x,
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[i * kElemCols + j] =
|
||||
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
|
||||
} else {
|
||||
dst[i * kElemCols + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstPtrType, typename StrX, typename StrY>
|
||||
METAL_FUNC static constexpr void
|
||||
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename DstPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename LimX,
|
||||
typename LimY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void store_safe(
|
||||
const thread frag_type& src,
|
||||
DstPtrType dst,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
LimX lim_x,
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
||||
static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread frag_type& D,
|
||||
thread frag_type& A,
|
||||
thread frag_type& B,
|
||||
thread frag_type& C) {
|
||||
mat_type D_mat;
|
||||
mat_type A_mat;
|
||||
mat_type B_mat;
|
||||
mat_type C_mat;
|
||||
|
||||
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
|
||||
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
|
||||
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
|
||||
|
||||
mma(D_mat, A_mat, B_mat, C_mat);
|
||||
|
||||
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread mat_type& D,
|
||||
thread mat_type& A,
|
||||
thread mat_type& B,
|
||||
thread mat_type& C) {
|
||||
simdgroup_multiply_accumulate(D, A, B, C);
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int kTileRows_,
|
||||
int kTileCols_,
|
||||
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
|
||||
struct MMATile {
|
||||
using MMAFrag_t = MMAFrag_;
|
||||
using elem_type = T;
|
||||
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
|
||||
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
|
||||
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
|
||||
|
||||
STEEL_CONST int kTileRows = kTileRows_;
|
||||
STEEL_CONST int kTileCols = kTileCols_;
|
||||
|
||||
STEEL_CONST int kRows = kTileRows * kFragRows;
|
||||
STEEL_CONST int kCols = kTileCols * kFragCols;
|
||||
|
||||
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
|
||||
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
|
||||
|
||||
typedef typename MMAFrag_t::mat_type mat_type;
|
||||
typedef typename MMAFrag_t::frag_type frag_type;
|
||||
|
||||
frag_type val_frags[kNumFrags] = {frag_type(0)};
|
||||
|
||||
METAL_FUNC MMATile() thread {}
|
||||
|
||||
METAL_FUNC constexpr void clear() {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kNumFrags; ++i) {
|
||||
val_frags[i] = frag_type(0);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
|
||||
return val_frags[i * kTileCols + j];
|
||||
}
|
||||
|
||||
METAL_FUNC constexpr const thread frag_type& frag_at(
|
||||
const short i,
|
||||
const short j) const {
|
||||
return val_frags[i * kTileCols + j];
|
||||
}
|
||||
|
||||
METAL_FUNC mat_type mat_at(const short i, const short j) {
|
||||
mat_type val_mat;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
|
||||
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
|
||||
}
|
||||
return val_mat;
|
||||
}
|
||||
|
||||
METAL_FUNC thread elem_type* elems() {
|
||||
return reinterpret_cast<thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
METAL_FUNC const thread elem_type* elems() const {
|
||||
return reinterpret_cast<const thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void load(const threadgroup U* src) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load(
|
||||
frag_at(i, j),
|
||||
&(
|
||||
src[(i * kFragRows) * w_x * str_x +
|
||||
(j * kFragCols) * w_y * str_y]),
|
||||
Int<str_x>{},
|
||||
Int<str_y>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void store(threadgroup U* dst) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store(
|
||||
frag_at(i, j),
|
||||
&(
|
||||
dst[(i * kFragRows) * w_x * str_x +
|
||||
(j * kFragCols) * w_y * str_y]),
|
||||
Int<str_x>{},
|
||||
Int<str_y>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void load(const device U* src, const int ld) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load(
|
||||
frag_at(i, j),
|
||||
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
||||
ld,
|
||||
Int<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void store(device U* dst, const int ld) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store(
|
||||
frag_at(i, j),
|
||||
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
||||
ld,
|
||||
Int<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void
|
||||
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load_safe(
|
||||
frag_at(i, j),
|
||||
src,
|
||||
ld,
|
||||
Int<1>{},
|
||||
src_tile_dims.y,
|
||||
src_tile_dims.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void
|
||||
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store_safe(
|
||||
frag_at(i, j),
|
||||
dst,
|
||||
ld,
|
||||
Int<1>{},
|
||||
dst_tile_dims.y,
|
||||
dst_tile_dims.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int M, int N, int K>
|
||||
METAL_FUNC void tile_matmad(
|
||||
thread MMATile<T, M, N>& D,
|
||||
thread MMATile<U, M, K>& A,
|
||||
thread MMATile<U, K, N>& B,
|
||||
thread MMATile<T, M, N>& C) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short m = 0; m < M; ++m) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short n = 0; n < N; ++n) {
|
||||
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < K; ++k) {
|
||||
MMATile<T, M, N>::MMAFrag_t::mma(
|
||||
D.frag_at(m, n_serp),
|
||||
A.frag_at(m, k),
|
||||
B.frag_at(k, n_serp),
|
||||
C.frag_at(m, n_serp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
@@ -33,39 +375,38 @@ template <
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct BlockMMA {
|
||||
// MMAFrag size
|
||||
STEEL_CONST short kFragSize = 8;
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TM_stride = 8 * WM;
|
||||
STEEL_CONST short TM_stride = kFragSize * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TN_stride = 8 * WN;
|
||||
STEEL_CONST short TN_stride = kFragSize * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
STEEL_CONST short simd_stride_a = {
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
||||
STEEL_CONST short simd_stride_b = {
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
||||
// Threadgroup A strides
|
||||
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
|
||||
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
|
||||
|
||||
// Jump between elements
|
||||
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
||||
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
||||
// Threadgroup B strides
|
||||
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
|
||||
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
|
||||
|
||||
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
||||
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
||||
// Threadgroup strides along K
|
||||
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
|
||||
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
|
||||
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
|
||||
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
|
||||
|
||||
// Offsets within threadgroup
|
||||
const short tm;
|
||||
const short tn;
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
@@ -75,18 +416,21 @@ struct BlockMMA {
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMA(
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
// Determine thread position in simdgroup matrix
|
||||
short qid = simd_lane_id / 4;
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
short tm = kFragSize * (simd_group_id / WN);
|
||||
short tn = kFragSize * (simd_group_id % WN);
|
||||
|
||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
sm = simd_coord.y;
|
||||
sn = simd_coord.x;
|
||||
|
||||
// Determine thread and simdgroup offset
|
||||
As_offset =
|
||||
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
||||
Bs_offset =
|
||||
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
||||
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
|
||||
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
|
||||
|
||||
sm += tm;
|
||||
sn += tn;
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
@@ -95,47 +439,20 @@ struct BlockMMA {
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of 8
|
||||
// Iterate over BK in blocks of kFragSize
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
for (short kk = 0; kk < BK; kk += kFragSize) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
||||
Asimd[i].thread_elements()[1] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
||||
}
|
||||
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
||||
Bsimd[j].thread_elements()[1] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
||||
}
|
||||
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
||||
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j_serp],
|
||||
Asimd[i],
|
||||
Bsimd[j_serp],
|
||||
results[i * TN + j_serp]);
|
||||
}
|
||||
}
|
||||
tile_matmad(Ctile, Atile, Btile, Ctile);
|
||||
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
@@ -144,58 +461,35 @@ struct BlockMMA {
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* D, const int ldd) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
METAL_FUNC void store_result(device U* D, const int ldd) {
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
||||
|
||||
// Write out D
|
||||
D[offset] = outs[0];
|
||||
D[offset + 1] = outs[1];
|
||||
}
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
||||
}
|
||||
|
||||
// Adjust for simdgroup and thread location
|
||||
D += sm * ldd + sn;
|
||||
|
||||
Ctile.template store<U, WM, WN>(D, ldd);
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
|
||||
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
||||
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}

// Adjust for simdgroup and thread location
D += (sm + tm) * ldd + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);

if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;

STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldd + (j * TN_stride);

// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset] = Epilogue::apply(accum[0]);
}

if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}

/* Apply epilogue */
@@ -203,16 +497,8 @@ struct BlockMMA {
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();

// Apply epilogue
accum[0] = epilogue_op.apply(accum[0]);
accum[1] = epilogue_op.apply(accum[1]);
}
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
}
}

@@ -224,7 +510,7 @@ struct BlockMMA {
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
C += (sm)*ldc + (sn)*fdc;

// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
@@ -232,12 +518,14 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
@@ -251,8 +539,8 @@ struct BlockMMA {
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
dst_tile_dims -= short2(tn + sn, sm + tm);
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);

if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
@@ -263,22 +551,26 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

// Read C
U c_elems[2] = {0};
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

if ((j * TN_stride + 1) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
c_elems[1] = C[offset_c + fdc];
} else if ((j * TN_stride) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
// Read C
U c_elems[kelems] = {0};

STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
c_elems[k] = C[offset_c + k * fdc];
}
}

// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
}
}
}
}
@@ -292,8 +584,10 @@ struct BlockMMA {
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;

constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
@@ -301,18 +595,15 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};

// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
@@ -326,30 +617,32 @@ struct BlockMMA {
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);

if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;

constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}

if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[offset_d + k] =
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
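Note: the hunks above collapse the hand-unrolled accum[0]/accum[1] epilogue code into loops over the tile fragment (kElemsPerFrag / kElemsPerTile), so the same code works for any fragment width. A standalone C++ sketch of that pattern, with made-up Frag and ScaleEpilogue types that are not the MLX kernel types:

#include <array>
#include <cstdio>

// Hedged sketch, not the MLX kernel: a fragment of N accumulator elements and
// an epilogue applied uniformly, mirroring the loop-over-kElemsPerFrag change.
template <typename T, int N>
struct Frag {
  std::array<T, N> elems{};
  static constexpr int kElemsPerFrag = N;
};

struct ScaleEpilogue {
  float alpha;
  float apply(float x) const { return alpha * x; }
  float apply(float x, float c) const { return alpha * x + c; }  // with C input
};

int main() {
  Frag<float, 4> acc{{1.f, 2.f, 3.f, 4.f}};
  float C[4] = {0.5f, 0.5f, 0.5f, 0.5f};
  ScaleEpilogue op{2.0f};
  // One loop covers any fragment width instead of hard-coding two elements.
  for (int k = 0; k < decltype(acc)::kElemsPerFrag; ++k) {
    acc.elems[k] = op.apply(acc.elems[k], C[k]);
  }
  for (float v : acc.elems) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}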
mlx/backend/metal/kernels/steel/utils/integral_constant.h (new file, 96 lines)
@@ -0,0 +1,96 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/utils/type_traits.h"

#pragma METAL internals : enable

namespace mlx {
namespace steel {

///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////

template <typename T, T v>
struct integral_constant {
static constexpr constant T value = v;
using value_type = T;
using type = integral_constant;

METAL_FUNC constexpr operator value_type() const noexcept {
return value;
}

// METAL_FUNC constexpr value_type operator()() const noexcept {
// return value;
// }
};

template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;

template <class T>
struct is_integral : bool_constant<metal::is_integral<T>::value> {};

template <class T, T v>
struct is_integral<integral_constant<T, v>>
: bool_constant<metal::is_integral<T>::value> {};

template <typename T>
constexpr constant bool is_integral_v = is_integral<T>::value;

template <int val>
using Int = integral_constant<int, val>;

///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////

#define integral_const_binop(__op__, __operator__) \
template <typename T, T tv, typename U, U uv> \
METAL_FUNC constexpr auto __operator__( \
integral_constant<T, tv>, integral_constant<U, uv>) { \
constexpr auto res = tv __op__ uv; \
return integral_constant<decltype(res), res>{}; \
}

integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);

integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);

integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);

#undef integral_const_binop

///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////

template <typename T>
METAL_FUNC constexpr T sum(T x) {
return x;
}

template <typename T, typename... Us>
METAL_FUNC constexpr auto sum(T x, Us... us) {
return x + sum(us...);
}

} // namespace steel
} // namespace mlx

#pragma METAL internals : disable
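Note: steel::integral_constant mirrors std::integral_constant but adds constexpr arithmetic and comparison operators plus a variadic sum, so tile sizes can be carried and combined as types. A host-side C++ analogue for illustration only (the Int and sum names are simplified stand-ins; the Metal version also carries address-space qualifiers):

#include <cstdio>

// Minimal analogue of the steel helpers in plain C++ (illustrative only).
template <int v>
struct Int {
  static constexpr int value = v;
  constexpr operator int() const noexcept { return v; }
};

// Adding two compile-time ints yields another compile-time int.
template <int a, int b>
constexpr auto operator+(Int<a>, Int<b>) {
  return Int<a + b>{};
}

template <typename T>
constexpr T sum(T x) {
  return x;
}

template <typename T, typename... Us>
constexpr auto sum(T x, Us... us) {
  return x + sum(us...);
}

int main() {
  constexpr auto total = sum(Int<8>{}, Int<16>{}, Int<32>{});
  static_assert(decltype(total)::value == 56, "evaluated at compile time");
  std::printf("%d\n", int(total));
  return 0;
}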
mlx/backend/metal/kernels/steel/utils/type_traits.h (new file, 55 lines)
@@ -0,0 +1,55 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>

#pragma METAL internals : enable

namespace metal {

template <typename T>
struct is_empty : metal::bool_constant<__is_empty(T)> {};

#ifdef __cpp_variable_templates
template <typename T>
constexpr constant bool is_empty_v = is_empty<T>::value;
#endif

template <typename... Ts>
struct make_void {
typedef void type;
};

template <typename... Ts>
using void_t = typename make_void<Ts...>::type;

template <class T>
struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};

template <typename T>
struct pointer_element {};

template <typename T>
struct pointer_element<thread T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<device T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<constant T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<threadgroup T*> {
using type = remove_cv_t<T>;
};

template <typename T>
using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;

} // namespace metal

#pragma METAL internals : disable
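Note: pointer_element strips the Metal address space (thread/device/constant/threadgroup) and cv-qualifiers to recover a pointer's element type. Standard C++ has no address spaces, so the closest host-side analogue is simply remove_pointer plus remove_cv, sketched below for illustration:

#include <type_traits>

// Rough host-side analogue of pointer_element_t: with no address spaces,
// remove_pointer_t + remove_cv_t already recover the element type.
template <typename T>
using pointer_element_t =
    std::remove_cv_t<std::remove_pointer_t<std::remove_cv_t<T>>>;

static_assert(std::is_same_v<pointer_element_t<const float*>, float>, "");
static_assert(std::is_same_v<pointer_element_t<int* const>, int>, "");

int main() { return 0; }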
@@ -88,6 +88,83 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////

#define GEMM_TPARAM_MACRO(devc) \
if (devc == 'g') { /* Small device */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else if (out.dtype() != float32) { /* half and bfloat */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} else if (devc == 'd') { /* Large device */ \
if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \
if (out.dtype() != float32) { /* half and bfloat */ \
if (2 * std::max(M, N) > K) { /* Reasonable K */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} else if (!transpose_a && transpose_b) { /* nt with large k */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else { /* nn with large K */ \
bm = 32; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} /* float takes default */ \
} else { /* smaller matmul */ \
if (out.dtype() != float32) { /* half and bfloat */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else { /* nn */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} else { /* floats */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 32; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} else { /* nn */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} \
} \
} \
} else { /* Medium device */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 2; \
wn = 2; \
}

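Note: GEMM_TPARAM_MACRO writes bm/bn/bk/wm/wn in place at each call site, starting from the 64x64x16, 2x2 defaults that the call sites below now initialize. A rough standalone C++ sketch of the same selection logic, for illustration only (the choose_tile_params name, TileParams struct, and boolean flags are mine, not MLX API; the fallthrough default matches the call-site initializers):

#include <algorithm>
#include <cstdio>

struct TileParams { int bm, bn, bk, wm, wn; };

// Hedged sketch of the device/shape heuristic, not the exact MLX code:
// start from the 64x64x16 / 2x2 default and override for the cases the
// macro singles out ('g' ~ small device, 'd' ~ large device).
TileParams choose_tile_params(char devc, bool ta, bool tb, bool fp32,
                              long batch, long M, long N, long K) {
  TileParams p{64, 64, 16, 2, 2};  // medium-device / fallthrough default
  bool nt = !ta && tb;
  if (devc == 'g') {
    if (nt) p = {64, 32, 32, 2, 2};
    else if (!fp32) p = {64, 64, 16, 1, 2};
  } else if (devc == 'd') {
    bool large = (unsigned long)batch * M * N >= (1ul << 20);
    if (large && !fp32) {
      if (2 * std::max(M, N) > K) p = {64, 64, 16, 1, 2};
      else if (nt) p = {64, 32, 32, 2, 2};
      else p = {32, 64, 16, 1, 2};
    } else if (!large) {
      if (!fp32) p = nt ? TileParams{64, 32, 32, 2, 2} : TileParams{64, 64, 16, 1, 2};
      else p = nt ? TileParams{32, 64, 16, 1, 2} : TileParams{64, 32, 32, 2, 2};
    }
  }
  return p;
}

int main() {
  auto p = choose_tile_params('d', false, true, false, 1, 4096, 4096, 64);
  std::printf("bm=%d bn=%d bk=%d wm=%d wn=%d\n", p.bm, p.bn, p.bk, p.wm, p.wn);
  return 0;
}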
void steel_matmul_regular(
const Stream& s,
metal::Device& d,
@@ -112,19 +189,11 @@ void steel_matmul_regular(
using namespace mlx::steel;

// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;

if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)

// Prepare kernel name
std::ostringstream kname;
@@ -903,19 +972,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Regular addmm dispatch

// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;

if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)

// Prepare kernel name
std::ostringstream kname;
@@ -1667,19 +1728,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Regular kernel dispatch

// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;

if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)

// Prepare kernel name
std::ostringstream kname;
@@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// organize into grid nkeys x elem_per_key
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(keys, 0);
@@ -5,11 +5,10 @@

namespace mlx::core::metal {

// TODO maybe worth including tvos / visionos
#define supported __builtin_available(macOS 15, iOS 18, *)

ResidencySet::ResidencySet(MTL::Device* d) {
if (supported) {
if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
return;
} else if (__builtin_available(macOS 15, iOS 18, *)) {
auto pool = new_scoped_memory_pool();
auto desc = MTL::ResidencySetDescriptor::alloc()->init();
NS::Error* error;
@@ -27,68 +26,72 @@ ResidencySet::ResidencySet(MTL::Device* d) {
}

void ResidencySet::insert(MTL::Allocation* buf) {
if (supported) {
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
wired_set_->addAllocation(buf);
wired_set_->commit();
wired_set_->requestResidency();
} else {
unwired_set_.insert(buf);
}
if (!wired_set_) {
return;
}
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
wired_set_->addAllocation(buf);
wired_set_->commit();
wired_set_->requestResidency();
} else {
unwired_set_.insert(buf);
}
}

void ResidencySet::erase(MTL::Allocation* buf) {
if (supported) {
if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {
unwired_set_.erase(it);
} else {
wired_set_->removeAllocation(buf);
wired_set_->commit();
}
if (!wired_set_) {
return;
}
if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {
unwired_set_.erase(it);
} else {
wired_set_->removeAllocation(buf);
wired_set_->commit();
}
}

void ResidencySet::resize(size_t size) {
if (supported) {
if (capacity_ == size) {
return;
}
capacity_ = size;
if (!wired_set_) {
return;
}

size_t current_size = wired_set_->allocatedSize();
if (capacity_ == size) {
return;
}
capacity_ = size;

if (current_size < size) {
// Add unwired allocations to the set
for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {
auto buf_size = (*it)->allocatedSize();
if (current_size + buf_size > size) {
it++;
} else {
current_size += buf_size;
wired_set_->addAllocation(*it);
unwired_set_.erase(it++);
}
size_t current_size = wired_set_->allocatedSize();

if (current_size < size) {
// Add unwired allocations to the set
for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {
auto buf_size = (*it)->allocatedSize();
if (current_size + buf_size > size) {
it++;
} else {
current_size += buf_size;
wired_set_->addAllocation(*it);
unwired_set_.erase(it++);
}
wired_set_->commit();
wired_set_->requestResidency();
} else if (current_size > size) {
// Remove wired allocations until under capacity
auto allocations = wired_set_->allAllocations();
auto num_allocations = wired_set_->allocationCount();
for (int i = 0; i < num_allocations && current_size > size; ++i) {
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
wired_set_->removeAllocation(buf);
current_size -= buf->allocatedSize();
unwired_set_.insert(buf);
}
wired_set_->commit();
}
wired_set_->commit();
wired_set_->requestResidency();
} else if (current_size > size) {
// Remove wired allocations until under capacity
auto allocations = wired_set_->allAllocations();
auto num_allocations = wired_set_->allocationCount();
for (int i = 0; i < num_allocations && current_size > size; ++i) {
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
wired_set_->removeAllocation(buf);
current_size -= buf->allocatedSize();
unwired_set_.insert(buf);
}
wired_set_->commit();
}
}

ResidencySet::~ResidencySet() {
if (supported) {
if (wired_set_) {
wired_set_->release();
}
}
@@ -72,6 +72,7 @@ void ternary_op_gpu_inplace(
compute_encoder.set_input_array(donate_c ? out : c, 2);
compute_encoder.set_output_array(out, 3);

auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (topt == TernaryOpType::General) {
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -93,7 +94,6 @@ void ternary_op_gpu_inplace(
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
}

NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
}
@@ -103,13 +103,12 @@ void ternary_op_gpu_inplace(
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -47,9 +47,7 @@ void unary_op_gpu_inplace(
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);

MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(
@@ -75,6 +73,8 @@ void unary_op_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -103,6 +103,9 @@ MTL::Size get_2d_grid_dims(
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
@@ -145,6 +148,9 @@ MTL::Size get_2d_grid_dims(
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
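Note: both hunks add the same normalization: after factoring the launch size into grid_x and grid_y, the larger factor is kept in x. A simplified standalone sketch of the idea (the grid_2d name and its naive divisor search are illustrative only; the MLX version derives the split from shape and strides, which is why the swap can trigger there):

#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <utility>

// Illustrative only: factor n into a 2D grid and keep the larger extent in x,
// mirroring the swap added to get_2d_grid_dims.
std::pair<uint32_t, uint32_t> grid_2d(uint64_t n) {
  // pick grid_y as the largest divisor of n not exceeding sqrt(n)
  uint64_t grid_y = 1;
  for (uint64_t f = 1; f * f <= n; ++f) {
    if (n % f == 0) {
      grid_y = f;
    }
  }
  uint64_t grid_x = n / grid_y;
  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
    throw std::runtime_error("Unable to safely factor shape.");
  }
  if (grid_y > grid_x) {
    std::swap(grid_x, grid_y);  // larger extent always goes in x
  }
  return {static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y)};
}

int main() {
  auto [x, y] = grid_2d(1u << 20);
  std::printf("grid = %u x %u\n", x, y);
  return 0;
}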
@@ -290,8 +290,15 @@ array bernoulli(
throw std::invalid_argument(
"[bernoulli] bernoulli probability `p` must be a float type.");
}
auto res = uniform(shape, p.dtype(), key, s);
res = less(res, p, s);

// Place p on the scale [0, nexthigher(UINT32_MAX)] so that if p >= 1.0 we
// get all true and if p <= 0.0 we get all false
auto upper = array(
std::nextafter(
static_cast<float>(std::numeric_limits<uint32_t>::max()),
std::numeric_limits<float>::max()),
float32);
auto res = less(bits(shape, key, s), multiply(p, upper, s), s);
if (res.shape() != shape) {
throw std::invalid_argument(
"[bernoulli] shape of `p` is incompatible with argument `shape`.");
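Note: bernoulli now compares raw 32-bit random words against p scaled to nextafter(float(UINT32_MAX), inf) instead of drawing uniforms first, so p >= 1 always passes and p <= 0 never does. A small standalone C++ sketch of the threshold math (std::mt19937 stands in here for MLX's keyed bit generator; the function name is mine):

#include <cstdint>
#include <cstdio>
#include <limits>
#include <random>

// Hedged sketch of the bernoulli change: compare raw 32-bit words against
// p scaled to just above the largest representable uint32 value.
bool bernoulli_from_bits(uint32_t bits, float p) {
  float upper = std::nextafter(
      static_cast<float>(std::numeric_limits<uint32_t>::max()),
      std::numeric_limits<float>::max());
  // p >= 1.0 -> threshold above every float(bits) -> always true;
  // p <= 0.0 -> threshold <= 0 -> always false.
  return static_cast<float>(bits) < p * upper;
}

int main() {
  std::mt19937 gen(0);
  int hits = 0, n = 100000;
  for (int i = 0; i < n; ++i) {
    hits += bernoulli_from_bits(static_cast<uint32_t>(gen()), 0.25f);
  }
  std::printf("empirical p ~= %.3f\n", double(hits) / n);
  return 0;
}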
@@ -686,6 +686,17 @@ std::vector<array> vmap_replace(
throw std::invalid_argument(msg.str());
}

int vmap_size = -1;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] >= 0) {
vmap_size = inputs[i].shape(in_axes[i]);
break;
}
}
if (vmap_size == -1) {
throw std::invalid_argument("At least one of in_axes must be non-None.");
}

std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
std::unordered_set<std::uintptr_t> needs_vmap;
std::unordered_set<std::uintptr_t> cache;
@@ -782,7 +793,11 @@ std::vector<array> vmap_replace(
}
outputs.push_back(out);
} else {
outputs.push_back(s_outputs[i]);
// When the output has no input dependencies
// use the size of the vmapped axis in the inputs to expand the output
array output = expand_dims(s_outputs[i], out_axes[i]);
output = repeat(output, vmap_size, out_axes[i]);
outputs.push_back(output);
}
}
return outputs;
@@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):


def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)


def _linear_indices(N, scale, align_corners, dim, ndims):
@@ -37,8 +37,8 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
weight = mx.expand_dims(weight, -1)

return (
(indices_l.astype(mx.int32), 1 - weight),
(indices_r.astype(mx.int32), weight),
(indices_l.astype(mx.uint32), 1 - weight),
(indices_r.astype(mx.uint32), weight),
)


@@ -73,10 +73,10 @@ def _cubic_indices(N, scale, align_corners, dim, ndims):
indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)

return (
(indices_l1.astype(mx.int32), weight_l1),
(indices_r1.astype(mx.int32), weight_r1),
(indices_l2.astype(mx.int32), weight_l2),
(indices_r2.astype(mx.int32), weight_r2),
(indices_l1.astype(mx.uint32), weight_l1),
(indices_r1.astype(mx.uint32), weight_r1),
(indices_l2.astype(mx.uint32), weight_l2),
(indices_r2.astype(mx.uint32), weight_r2),
)

@@ -848,12 +848,19 @@ void init_array(nb::module_& m) {
.def(
"__format__",
[](array& a, nb::object format_spec) {
if (a.ndim() > 0) {
if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) {
throw nb::type_error(
"unsupported format string passed to mx.array.__format__");
} else if (a.ndim() == 0) {
auto obj = to_scalar(a);
return nb::cast<std::string>(
nb::handle(PyObject_Format(obj.ptr(), format_spec.ptr())));
} else {
nb::gil_scoped_release nogil;
std::ostringstream os;
os << a;
return os.str();
}
auto obj = to_scalar(a);
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
})
.def(
"flatten",
@@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase):
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)

def fun():
a = mx.array([1.0, 2.0, 3.0, 4.0])
b, _ = mx.divmod(a, a)
return mx.log(b)

fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)

def test_add_numpy(self):
x = mx.array(1)
y = np.array(2, dtype=np.int32)
@@ -1885,6 +1898,9 @@ class TestArray(mlx_tests.MLXTestCase):
with self.assertRaises(TypeError):
s = f"{a:.2f}"

a = mx.array([1, 2, 3])
self.assertEqual(f"{a}", "array([1, 2, 3], dtype=int32)")

def test_deep_graphs(self):
# The following tests should simply run cleanly without a segfault or
# crash due to exceeding recursion depth limits.
@@ -1089,12 +1089,14 @@ class TestOps(mlx_tests.MLXTestCase):
a_mlx = mx.array(a_np)

if ax == None:
idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
idx_np = np.random.permutation(a_np.size)
values_np = np.random.randint(low=0, high=100, size=(16,))
else:
shape = list(a_np.shape)
shape[ax] = 2
idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))
idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))
idx_np = np.broadcast_to(idx_np, shape)
values_np = np.random.randint(low=0, high=100, size=shape)

idx_np.astype(np.int32)
@@ -462,6 +462,26 @@ class TestVmap(mlx_tests.MLXTestCase):
expected[:, 0] = mx.array([1, 2, 3])[:, None]
self.assertTrue(mx.allclose(out, expected))

def test_vmap_const_func(self):
a = mx.random.uniform(shape=(2, 3, 4))
b = mx.random.uniform(shape=(4, 3))

def const_func(a, b):
return mx.array(2)

out = mx.vmap(const_func, in_axes=(0, None))(a, b)
self.assertTrue(mx.array_equal(mx.full((2,), 2), out))
out = mx.vmap(const_func, in_axes=(None, 0))(a, b)
self.assertTrue(mx.array_equal(mx.full((4,), 2), out))
out = mx.vmap(const_func, in_axes=(1, 1))(a, b)
self.assertTrue(mx.array_equal(mx.full((3,), 2), out))

with self.assertRaises(ValueError):
out = mx.vmap(const_func, in_axes=(None, None))(a, b)

with self.assertRaises(ValueError):
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)


if __name__ == "__main__":
unittest.main()

setup.py (2 lines changed)
@@ -165,7 +165,7 @@ if __name__ == "__main__":

setup(
name="mlx",
version=get_version("0.19.1"),
version=get_version("0.19.3"),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.",
@@ -34,12 +34,8 @@ TEST_CASE("test simple vmap") {
CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);

auto vfun = vmap(fun, -1, -1);
auto x = zeros({2});
CHECK(array_equal(vfun(x), zeros({4, 2})).item<bool>());

vfun = vmap(fun);
x = zeros({3, 2});
auto vfun = vmap(fun);
auto x = zeros({3, 2});
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());

vfun = vmap(fun, 0, 1);
@@ -121,16 +117,9 @@ TEST_CASE("test simple vmap") {
out = vfun({x, y})[0];
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());

CHECK_THROWS_AS(vmap(fun, {-1, -1}, {0}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);

x = array(1.);
y = array(1.);
vfun = vmap(fun, {-1, -1}, {-1});
out = vfun({x, y})[0];
CHECK(array_equal(out, array(2.)).item<bool>());

x = ones({3, 2, 1});
y = ones({3, 2, 1});
vfun = vmap(vmap(fun));
@@ -187,13 +176,6 @@ TEST_CASE("test simple vmap") {
CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);

cond = array({true, false});
x = array(1.);
y = array(2.);
vfun = vmap(fun, {-1, -1, -1}, {-1});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1.0, 2.0})).item<bool>());

cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
x = ones({3, 2, 1});
y = full({3, 2, 1}, 2);
@@ -424,21 +406,6 @@ TEST_CASE("test vmap scatter") {
};
};

{
// vmap nothing.
auto a = zeros({3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});

auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0];
auto expected =
array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
// Non-vmapped function output.
CHECK(array_equal(func({a}).at(0), expected).item<bool>());
CHECK(array_equal(out, expected).item<bool>());
}

{
// vmap src on axis 0, scatter on axis 0.
auto a = zeros({2, 3, 4});