Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 01:17:26 +08:00)

Fully wrap the command encoder (#1572)

* fully wrap the command encoder
* use consistent style + fix extensions

Parent: 59247c2b62
Commit: 9f0d5c12fc
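
The change is mechanical throughout: every call that reached through the encoder's `operator->` into the raw `MTL::ComputeCommandEncoder` becomes a typed wrapper method on `CommandEncoder`. A minimal before/after sketch, condensed from the hunks below (the surrounding kernel and array setup is assumed):

    // Before: raw Metal calls through operator->, byte counts spelled out by hand.
    compute_encoder->setComputePipelineState(kernel);
    compute_encoder->setBytes(&alpha_, sizeof(float), 3);
    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
    compute_encoder.dispatchThreads(grid_dims, group_dims);

    // After: wrapper methods deduce the byte counts from the argument types.
    compute_encoder.set_compute_pipeline_state(kernel);
    compute_encoder.set_bytes(alpha_, 3);
    compute_encoder.set_vector_bytes(x.shape(), 5);
    compute_encoder.dispatch_threads(grid_dims, group_dims);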
@@ -494,7 +494,7 @@ below.

     // Prepare to encode kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);

     // Kernel parameters are registered with buffer indices corresponding to
     // those in the kernel declaration at axpby.metal
@@ -509,14 +509,14 @@ below.
     compute_encoder.set_output_array(out, 2);

     // Encode alpha and beta
-    compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-    compute_encoder->setBytes(&beta_, sizeof(float), 4);
+    compute_encoder.set_bytes(alpha_, 3);
+    compute_encoder.set_bytes(beta_, 4);

     // Encode shape, strides and ndim
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);

     // We launch 1 thread for each input and make sure that the number of
     // threads in any given threadgroup is not higher than the max allowed
@@ -530,7 +530,7 @@ below.

     // Launch the grid with the given number of threads divided among
     // the given threadgroups
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 We can now call the :meth:`axpby` operation on both the CPU and the GPU!
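
For a quick smoke test of the updated extension, a hypothetical call site (a sketch only; `axpby` and its float alpha/beta signature come from the extensions tutorial this page documents, and the array helpers are assumed from the core API):

    // z = 4.0 * x + 2.0 * y, runnable on either the CPU or the GPU stream.
    auto x = mlx::core::ones({3, 4});
    auto y = mlx::core::ones({3, 4});
    auto z = axpby(x, y, 4.0f, 2.0f);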
@@ -257,7 +257,7 @@ void Axpby::eval_gpu(

   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Kernel parameters are registered with buffer indices corresponding to
   // those in the kernel declaration at axpby.metal
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
   compute_encoder.set_output_array(out, 2);

   // Encode alpha and beta
-  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-  compute_encoder->setBytes(&beta_, sizeof(float), 4);
+  compute_encoder.set_bytes(alpha_, 3);
+  compute_encoder.set_bytes(beta_, 4);

   // Encode shape, strides and ndim if needed
   if (!contiguous_kernel) {
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);
   }

   // We launch 1 thread for each input and make sure that the number of
@@ -295,7 +295,7 @@ void Axpby::eval_gpu(

   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 #else // Metal is not available
@@ -242,6 +242,9 @@ void MetalAllocator::clear_cache() {

 void MetalAllocator::free(Buffer buffer) {
   auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
+  if (buf == nullptr) {
+    return;
+  }
   std::unique_lock lk(mutex_);
   residency_set_.erase(buf);
   active_memory_ -= buf->length();
@@ -92,7 +92,7 @@ void binary_op_gpu_inplace(
       ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
       : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // - If a is donated it goes to the first output
   // - If b is donated it goes to the first output if a was not donated
@@ -117,19 +117,15 @@ void binary_op_gpu_inplace(
     size_t rest = out.size() / (dim0 * dim1);

     if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
-      compute_encoder->setBytes(
-          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
-      compute_encoder->setBytes(
-          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
-      compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
+      compute_encoder.set_vector_bytes(shape, arg_idx++);
+      compute_encoder.set_vector_bytes(strides_a, arg_idx++);
+      compute_encoder.set_vector_bytes(strides_b, arg_idx++);
+      compute_encoder.set_bytes<int>(ndim, arg_idx++);
       dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     } else {
       // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(
-          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
-      compute_encoder->setBytes(
-          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder.set_vector_bytes(strides_a, arg_idx++);
+      compute_encoder.set_vector_bytes(strides_b, arg_idx++);
     }

     if (thread_group_size != 1024) {
@@ -137,7 +133,7 @@ void binary_op_gpu_inplace(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
@@ -147,7 +143,7 @@ void binary_op_gpu_inplace(
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
                                  : MTL::Size(nthreads, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
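
One detail in the hunk above deserves a note: `set_bytes<int>(ndim, arg_idx++)` names the template argument explicitly. A plausible reading (an assumption, not stated in the diff): it pins the encoded size to `sizeof(int)`, matching the kernel-side `constant int&` declaration even if the declared type of `ndim` at the call site is ever widened. Given the templated wrapper defined later in this diff, the call lowers to:

    // Equivalent raw call produced by the templated wrapper:
    //   enc_->setBytes(&ndim, sizeof(int), arg_idx++);
    compute_encoder.set_bytes<int>(ndim, arg_idx++);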
@@ -373,7 +373,7 @@ void Compiled::eval_gpu(
   }
   auto kernel = d.get_kernel(kernel_name, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Put the inputs in
   int cnt = 0;
@@ -394,8 +394,7 @@ void Compiled::eval_gpu(
     }
   }
   if (!in_strides.empty()) {
-    compute_encoder->setBytes(
-        in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
+    compute_encoder.set_vector_bytes(in_strides, cnt++);
   }

   compiled_allocate_outputs(
@@ -408,14 +407,13 @@ void Compiled::eval_gpu(

   // Put the output shape and strides in
   if (!contiguous) {
-    compute_encoder->setBytes(
-        strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
-    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
+    compute_encoder.set_vector_bytes(strides[0], cnt++);
+    compute_encoder.set_vector_bytes(shape, cnt++);
   }

   // Put the number of dims in if it is dynamic
   if (dynamic) {
-    compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
+    compute_encoder.set_bytes(ndim, cnt++);
   }

   // Launch the kernel
@@ -427,7 +425,7 @@ void Compiled::eval_gpu(
     MTL::Size grid_dims = use_2d
         ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
         : MTL::Size(nthreads, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -445,7 +443,7 @@ void Compiled::eval_gpu(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
@@ -44,12 +44,12 @@ void explicit_gemm_conv_ND_gpu(
   kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(in_unfolded, 1);

-  compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
+  compute_encoder.set_bytes(conv_params, 2);

   // Launch unfolding kernel
   int tgp_x = std::min(conv_params.C, 64);
@@ -60,7 +60,7 @@ void explicit_gemm_conv_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);

-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);

   // Reshape weight
   std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -122,12 +122,12 @@ void explicit_gemm_conv_group_ND_gpu(
       << N;
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(in_unfolded, 1);

-  compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
+  compute_encoder.set_bytes(conv_params, 2);

   // Launch unfolding kernel
   int tgp_x = std::min(conv_params.C, 64);
@@ -138,7 +138,7 @@ void explicit_gemm_conv_group_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);

-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);

   // Transpose kernel weights so that we can slice them by contiguous chunks
   // of channel groups.
@@ -237,7 +237,7 @@ void slow_conv_2D_gpu(
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];

@@ -252,8 +252,8 @@ void slow_conv_2D_gpu(
   compute_encoder.set_input_array(wt, 1);
   compute_encoder.set_output_array(out, 2);

-  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.set_bytes(conv_params, 3);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

 void implicit_gemm_conv_2D_gpu(
@@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
       wn,
       n_channel_specialization,
       small_filter);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Deduce grid launch dimensions
   int tile = 1 << swizzle_log;
@@ -368,11 +368,11 @@ void implicit_gemm_conv_2D_gpu(
   compute_encoder.set_output_array(out, 2);

   // Encode params
-  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
+  compute_encoder.set_bytes(conv_params, 3);
+  compute_encoder.set_bytes(gemm_params, 4);

   // Launch kernel
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

 void implicit_gemm_conv_2D_general_gpu(
@@ -506,7 +506,7 @@ void implicit_gemm_conv_2D_general_gpu(
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel =
       get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Deduce grid launch dimensions
   int tile = 1 << swizzle_log;
@@ -523,17 +523,15 @@ void implicit_gemm_conv_2D_general_gpu(
   compute_encoder.set_output_array(out, 2);

   // Encode params
-  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
-  compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
+  compute_encoder.set_bytes(conv_params, 3);
+  compute_encoder.set_bytes(gemm_params, 4);
+  compute_encoder.set_bytes(jump_params, 5);

-  compute_encoder->setBytes(
-      base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
-  compute_encoder->setBytes(
-      base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
+  compute_encoder.set_vector_bytes(base_h, 6);
+  compute_encoder.set_vector_bytes(base_w, 7);

   // Launch kernel
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

 void winograd_conv_2D_gpu(
@@ -622,18 +620,18 @@ void winograd_conv_2D_gpu(
         << bc;
    auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
-   compute_encoder->setComputePipelineState(kernel);
+   compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(wt, 0);
    compute_encoder.set_output_array(filt_wg, 1);

-   compute_encoder->setBytes(&C_c, sizeof(int), 2);
-   compute_encoder->setBytes(&O_c, sizeof(int), 3);
+   compute_encoder.set_bytes(C_c, 2);
+   compute_encoder.set_bytes(O_c, 3);

    MTL::Size group_dims = MTL::Size(32, bo, 1);
    MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);

-   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
  }

  // Do input transform
@@ -650,18 +648,17 @@ void winograd_conv_2D_gpu(
        << bc;
    auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
-   compute_encoder->setComputePipelineState(kernel);
+   compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(in_padded, 0);
    compute_encoder.set_output_array(inp_wg, 1);

-   compute_encoder->setBytes(
-       &conv_params_updated, sizeof(MLXConvParams<2>), 2);
+   compute_encoder.set_bytes(conv_params_updated, 2);

    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

-   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
  }

  // Do batched gemm
@@ -698,18 +695,17 @@ void winograd_conv_2D_gpu(
        << bc;
    auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
-   compute_encoder->setComputePipelineState(kernel);
+   compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(out_wg, 0);
    compute_encoder.set_output_array(out, 1);

-   compute_encoder->setBytes(
-       &conv_params_updated, sizeof(MLXConvParams<2>), 2);
+   compute_encoder.set_bytes(conv_params_updated, 2);

    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

-   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
  }
 }
@@ -111,7 +111,7 @@ void copy_gpu_inplace(
   auto kernel = get_copy_kernel(d, kernel_name, in, out);

   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
   bool donate_in = in.data_shared_ptr() == nullptr;

   inp_offset *= size_of(in.dtype());
@@ -125,11 +125,11 @@ void copy_gpu_inplace(
     std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
     std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
     if (ndim > 3) {
-      set_vector_bytes(compute_encoder, shape, ndim, 2);
+      compute_encoder.set_vector_bytes(shape, ndim, 2);
     }
-    set_vector_bytes(compute_encoder, strides_in, ndim, 3);
+    compute_encoder.set_vector_bytes(strides_in, ndim, 3);
     if (ctype == CopyType::GeneralGeneral) {
-      set_vector_bytes(compute_encoder, strides_out, ndim, 4);
+      compute_encoder.set_vector_bytes(strides_out, ndim, 4);
     }

     int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -141,7 +141,7 @@ void copy_gpu_inplace(
     int rest = data_size / (dim0 * dim1);

     if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 5);
+      compute_encoder.set_bytes(ndim, 5);
       dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     }

@@ -152,7 +152,7 @@ void copy_gpu_inplace(

     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
     if (thread_group_size > nthreads) {
@@ -161,7 +161,7 @@ void copy_gpu_inplace(
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
                                  : MTL::Size(nthreads, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }

@@ -199,7 +199,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
       type_to_name(val) + type_to_name(out);
   auto kernel = get_copy_kernel(d, kernel_name, val, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   compute_encoder.set_input_array(val, 0);
   compute_encoder.set_output_array(out, 1);
@@ -212,7 +212,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
   MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
                                : MTL::Size(nthreads, 1, 1);
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 } // namespace mlx::core
@@ -43,7 +43,7 @@ void CustomKernel::eval_gpu(
       d.get_library(lib_name, [this] { return metal::utils() + source_; });
   auto kernel = d.get_kernel(name_, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
   int index = 0;
   for (int i = 0; i < checked_inputs.size(); i++) {
     const array& in = checked_inputs[i];
@@ -53,15 +53,15 @@ void CustomKernel::eval_gpu(
     if (in.ndim() > 0) {
       int ndim = in.ndim();
       if (shape_info.shape) {
-        set_vector_bytes(compute_encoder, in.shape(), ndim, index);
+        compute_encoder.set_vector_bytes(in.shape(), ndim, index);
         index++;
       }
       if (shape_info.strides) {
-        set_vector_bytes(compute_encoder, in.strides(), ndim, index);
+        compute_encoder.set_vector_bytes(in.strides(), ndim, index);
         index++;
       }
       if (shape_info.ndim) {
-        compute_encoder->setBytes(&ndim, sizeof(int), index);
+        compute_encoder.set_bytes(ndim, index);
         index++;
       }
     }
@@ -75,7 +75,7 @@ void CustomKernel::eval_gpu(
   MTL::Size group_dims = MTL::Size(tx, ty, tz);
   const auto [gx, gy, gz] = grid_;
   MTL::Size grid_dims = MTL::Size(gx, gy, gz);
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);

   d.add_temporaries(std::move(copies), s.index);
 }
@@ -171,14 +171,14 @@ void CommandEncoder::maybeInsertBarrier() {
   next_outputs_.clear();
 }

-void CommandEncoder::dispatchThreadgroups(
+void CommandEncoder::dispatch_threadgroups(
     MTL::Size grid_dims,
     MTL::Size group_dims) {
   maybeInsertBarrier();
   enc_->dispatchThreadgroups(grid_dims, group_dims);
 }

-void CommandEncoder::dispatchThreads(
+void CommandEncoder::dispatch_threads(
     MTL::Size grid_dims,
     MTL::Size group_dims) {
   maybeInsertBarrier();
@@ -298,7 +298,7 @@ void Device::end_encoding(int index) {
       if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
         // If we've already waited on a fence, don't wait on it again.
         if (waiting_on.find(it->second) == waiting_on.end()) {
-          enc->waitForFence(it->second->fence);
+          enc.wait_for_fence(it->second->fence);
           waiting_on.insert(it->second);
         }
       }
@@ -307,7 +307,7 @@ void Device::end_encoding(int index) {
       stream.outputs[out] = stream.fence;
     }
   }
-  enc->updateFence(stream.fence->fence);
+  enc.update_fence(stream.fence->fence);
   stream.buffer->addCompletedHandler(
       [&stream,
        waiting_on = std::move(waiting_on),
@@ -58,16 +58,43 @@ struct CommandEncoder {
     CommandEncoder& enc;
   };

-  MTL::ComputeCommandEncoder* operator->() {
-    return enc_;
-  }
-
   void set_input_array(const array& a, int idx, int64_t offset = 0);
   void set_output_array(array& a, int idx, int64_t offset = 0);
-  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
-  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+  void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+  void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
   void maybeInsertBarrier();

+  void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
+    enc_->setComputePipelineState(kernel);
+  }
+
+  void wait_for_fence(MTL::Fence* fence) {
+    enc_->waitForFence(fence);
+  }
+
+  void update_fence(MTL::Fence* fence) {
+    enc_->updateFence(fence);
+  }
+
+  template <typename T>
+  void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
+    enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
+  }
+  template <typename T>
+  void set_vector_bytes(const std::vector<T>& vec, int idx) {
+    return set_vector_bytes(vec, vec.size(), idx);
+  }
+
+  template <typename T>
+  void set_bytes(const T* v, int n, int idx) {
+    return enc_->setBytes(v, n * sizeof(T), idx);
+  }
+
+  template <typename T>
+  void set_bytes(const T& v, int idx) {
+    return enc_->setBytes(&v, sizeof(T), idx);
+  }
+
   ConcurrentContext start_concurrent() {
     return ConcurrentContext(*this);
   }
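
With this header in place, call sites never touch the raw `MTL::ComputeCommandEncoder` again. A condensed sketch of the resulting encoding sequence (names as used throughout the hunks in this commit):

    auto& enc = d.get_command_encoder(s.index);
    enc.set_compute_pipeline_state(kernel); // wraps setComputePipelineState
    enc.set_input_array(in, 0);
    enc.set_output_array(out, 1);
    enc.set_bytes(params, 2);               // encodes sizeof(params) bytes
    enc.set_vector_bytes(strides, 3);       // encodes vec.size() * sizeof(T) bytes
    enc.dispatch_threads(grid_dims, group_dims);

Because `set_bytes` and `set_vector_bytes` are templates, the byte counts that call sites previously spelled out by hand (`sizeof(float)`, `ndim * sizeof(size_t)`, and so on) are now derived from the argument's static type, which removes a whole class of size-mismatch bugs.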
@@ -699,7 +699,7 @@ void fft_op(
   auto kernel =
       get_fft_kernel(d, base_name, hash_name, func_consts, template_def);

-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
   compute_encoder.set_input_array(in_contiguous, 0);
   compute_encoder.set_output_array(out, 1);

@@ -711,9 +711,9 @@ void fft_op(

     compute_encoder.set_input_array(w_q, 2); // w_q
     compute_encoder.set_input_array(w_k, 3); // w_k
-    compute_encoder->setBytes(&n, sizeof(int), 4);
-    compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
-    compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
+    compute_encoder.set_bytes(n, 4);
+    compute_encoder.set_bytes(plan.bluestein_n, 5);
+    compute_encoder.set_bytes(total_batch_size, 6);
   } else if (plan.rader_n > 1) {
     auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
     copies.push_back(b_q);
@@ -723,22 +723,22 @@ void fft_op(
     compute_encoder.set_input_array(b_q, 2);
     compute_encoder.set_input_array(g_q, 3);
     compute_encoder.set_input_array(g_minus_q, 4);
-    compute_encoder->setBytes(&n, sizeof(int), 5);
-    compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
-    compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
+    compute_encoder.set_bytes(n, 5);
+    compute_encoder.set_bytes(total_batch_size, 6);
+    compute_encoder.set_bytes(plan.rader_n, 7);
   } else if (four_step_params.required) {
-    compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
-    compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
-    compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
+    compute_encoder.set_bytes(four_step_params.n1, 2);
+    compute_encoder.set_bytes(four_step_params.n2, 3);
+    compute_encoder.set_bytes(total_batch_size, 4);
   } else {
-    compute_encoder->setBytes(&n, sizeof(int), 2);
-    compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
+    compute_encoder.set_bytes(n, 2);
+    compute_encoder.set_bytes(total_batch_size, 3);
   }

   auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
   auto grid_dims =
       MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 d.add_temporaries(std::move(copies), s.index);
@@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kernel_name, lib);
     assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());

-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(out, 1);
-    compute_encoder->setBytes(&scale, sizeof(float), 2);
+    compute_encoder.set_bytes(scale, 2);

     MTL::Size group_dims = MTL::Size(1, threads_per, 1);
     MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   };

   if (m > 1) {
@@ -87,7 +87,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {

   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   size_t slice_size = 1;
   for (auto s : slice_sizes_) {
@@ -131,20 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 1);

   // Set source info
-  set_vector_bytes(compute_encoder, src.shape(), 2);
-  set_vector_bytes(compute_encoder, src.strides(), 3);
-  compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
-  set_vector_bytes(compute_encoder, slice_sizes_, 5);
-  set_vector_bytes(compute_encoder, axes_, 6);
+  compute_encoder.set_vector_bytes(src.shape(), 2);
+  compute_encoder.set_vector_bytes(src.strides(), 3);
+  compute_encoder.set_bytes(ndim, 4);
+  compute_encoder.set_vector_bytes(slice_sizes_, 5);
+  compute_encoder.set_vector_bytes(axes_, 6);

   // Set index info
   //
   // We don't need to check for empty idx_shapes because gather has a
   // idx_ndim == 0 specialization
-  set_vector_bytes(compute_encoder, idx_shapes, 7);
-  set_vector_bytes(compute_encoder, idx_strides, 8);
-  set_vector_bytes(compute_encoder, idx_contigs, 9);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
+  compute_encoder.set_vector_bytes(idx_shapes, 7);
+  compute_encoder.set_vector_bytes(idx_strides, 8);
+  compute_encoder.set_vector_bytes(idx_contigs, 9);
+  compute_encoder.set_bytes(idx_ndim, 10);

   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -152,7 +152,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   // Launch grid
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -289,7 +289,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {

   size_t nthreads = upd.size();

-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Set all the buffers
   compute_encoder.set_input_array(upd, 1);
@@ -323,14 +323,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Need placeholders so Metal doesn't compalain
     int shape_ = 0;
     size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 3);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
+    compute_encoder.set_bytes(shape_, 3);
+    compute_encoder.set_bytes(stride_, 4);
   } else {
-    set_vector_bytes(compute_encoder, upd.shape(), 3);
-    set_vector_bytes(compute_encoder, upd.strides(), 4);
+    compute_encoder.set_vector_bytes(upd.shape(), 3);
+    compute_encoder.set_vector_bytes(upd.strides(), 4);
   }
-  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
-  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
+  compute_encoder.set_bytes(upd_ndim, 5);
+  compute_encoder.set_bytes(upd_size, 6);

   // Set output info
   size_t out_ndim = out.ndim();
@@ -338,14 +338,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Need placeholders so Metal doesn't compalain
     int shape_ = 0;
     size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 7);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
+    compute_encoder.set_bytes(shape_, 7);
+    compute_encoder.set_bytes(stride_, 8);
   } else {
-    set_vector_bytes(compute_encoder, out.shape(), 7);
-    set_vector_bytes(compute_encoder, out.strides(), 8);
+    compute_encoder.set_vector_bytes(out.shape(), 7);
+    compute_encoder.set_vector_bytes(out.strides(), 8);
   }
-  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
-  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
+  compute_encoder.set_bytes(out_ndim, 9);
+  compute_encoder.set_vector_bytes(axes_, 10);

   // Set index info
   if (idx_ndim == 0) {
@@ -355,11 +355,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     idx_strides.push_back(0);
     idx_contigs.push_back(false);
   }
-  set_vector_bytes(compute_encoder, idx_shapes, 11);
-  set_vector_bytes(compute_encoder, idx_strides, 12);
-  set_vector_bytes(compute_encoder, idx_contigs, 13);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
-  compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
+  compute_encoder.set_vector_bytes(idx_shapes, 11);
+  compute_encoder.set_vector_bytes(idx_strides, 12);
+  compute_encoder.set_vector_bytes(idx_contigs, 13);
+  compute_encoder.set_bytes(idx_ndim, 14);
+  compute_encoder.set_bytes(idx_size, 15);

   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -375,7 +375,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
   }
   MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 } // namespace mlx::core
@@ -17,12 +17,15 @@ template <typename T, int N_READS = RMS_N_READS>
     constant float& eps,
     constant uint& axis_size,
     constant uint& w_stride,
-    threadgroup float* local_inv_mean [[threadgroup(0)]],
-    threadgroup float* local_sums [[threadgroup(1)]],
     uint gid [[threadgroup_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup float local_inv_mean[1];
+  threadgroup float local_sums[SIMD_SIZE];
+
   float acc = 0;
   x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
@@ -84,13 +87,15 @@ template <typename T, int N_READS = RMS_N_READS>
     constant float& eps,
     constant uint& axis_size,
     constant uint& w_stride,
-    threadgroup float* local_inv_mean [[threadgroup(0)]],
-    threadgroup float* local_sums [[threadgroup(1)]],
     uint gid [[threadgroup_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   constexpr int SIMD_SIZE = 32;
+  threadgroup float local_inv_mean[1];
+  threadgroup float local_sums[SIMD_SIZE];

   float acc = 0;
   x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
@@ -376,8 +381,6 @@ template <typename T, int N_READS = RMS_N_READS>
       constant float& eps,                                  \
       constant uint& axis_size,                             \
       constant uint& w_stride,                              \
-      threadgroup float* local_inv_mean [[threadgroup(0)]], \
-      threadgroup float* local_sums [[threadgroup(1)]],     \
       uint gid [[thread_position_in_grid]],                 \
       uint lid [[thread_position_in_threadgroup]],          \
       uint simd_lane_id [[thread_index_in_simdgroup]],      \
@@ -407,8 +410,6 @@ template <typename T, int N_READS = RMS_N_READS>
      constant float& eps,                                   \
      constant uint& axis_size,                              \
      constant uint& w_stride,                               \
-     threadgroup float* local_inv_mean [[threadgroup(0)]],  \
-     threadgroup float* local_sums [[threadgroup(1)]],      \
      uint gid [[thread_position_in_grid]],                  \
      uint lid [[thread_position_in_threadgroup]],           \
      uint lsize [[threads_per_threadgroup]],                \
@@ -249,7 +249,7 @@ void steel_matmul_regular(
       wm,
       wn);

-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Use problem size to determine threadblock swizzle
   int tn = (N + bn - 1) / bn;
@@ -288,12 +288,12 @@ void steel_matmul_regular(
   compute_encoder.set_input_array(b, 1);
   compute_encoder.set_output_array(out, 3);

-  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+  compute_encoder.set_bytes(params, 4);

-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
+  compute_encoder.set_vector_bytes(batch_shape, 6);
+  compute_encoder.set_vector_bytes(batch_strides, 7);

-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

   // Record copies
   d.add_temporaries(std::move(copies), s.index);
@@ -390,7 +390,7 @@ void steel_matmul(
       wn,
       mn_aligned,
       k_aligned);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   int tn = (N + bn - 1) / bn;
   int tm = (M + bm - 1) / bm;
@@ -416,34 +416,30 @@ void steel_matmul(
     compute_encoder.set_input_array(b, 1);
     compute_encoder.set_output_array(C_split, 2);

-    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.set_bytes(params, 3);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

     // Do accum kernel
     {
-      auto c_split_buf =
-          static_cast<const MTL::Resource*>(C_split.buffer().ptr());
-      const class MTL::Resource* const resources[1] = {c_split_buf};
-      compute_encoder->memoryBarrier(resources, 1);
       auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
           type_to_name(C_split);

       auto kernel = get_steel_gemm_splitk_accum_kernel(
           d, kernel_name, C_split, out, false);
-      compute_encoder->setComputePipelineState(kernel);
+      compute_encoder.set_compute_pipeline_state(kernel);

       // Set the arguments for the kernel
       compute_encoder.set_input_array(C_split, 0);
       compute_encoder.set_output_array(out, 1);
-      compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
-      compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
-      compute_encoder->setBytes(&N, sizeof(int), 4);
+      compute_encoder.set_bytes(split_k_partitions, 2);
+      compute_encoder.set_bytes(split_k_partition_stride, 3);
+      compute_encoder.set_bytes(N, 4);

       // Launch enough thread groups for each output
       MTL::Size grid_dims = MTL::Size(N, M, 1);
       MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);

-      compute_encoder.dispatchThreads(grid_dims, group_dims);
+      compute_encoder.dispatch_threads(grid_dims, group_dims);
     }

     d.add_temporaries(std::move(copies), s.index);
@@ -625,7 +621,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);

     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
     MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -635,16 +631,16 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(vec, 1);
     compute_encoder.set_output_array(out, 3);

-    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+    compute_encoder.set_bytes(in_vector_len, 4);
+    compute_encoder.set_bytes(out_vector_len, 5);
+    compute_encoder.set_bytes(mat_ld, 6);

-    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-    set_vector_bytes(compute_encoder, batch_shape, 10);
-    set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-    set_vector_bytes(compute_encoder, batch_strides_mat, 12);
+    compute_encoder.set_bytes(batch_ndim, 9);
+    compute_encoder.set_vector_bytes(batch_shape, 10);
+    compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+    compute_encoder.set_vector_bytes(batch_strides_mat, 12);

-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

     d.add_temporaries(std::move(copies), s.index);
     return;
@@ -822,7 +818,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);

     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
     MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -833,23 +829,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(c, 2);
     compute_encoder.set_output_array(out, 3);

-    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+    compute_encoder.set_bytes(in_vector_len, 4);
+    compute_encoder.set_bytes(out_vector_len, 5);
+    compute_encoder.set_bytes(mat_ld, 6);

-    compute_encoder->setBytes(&alpha_, sizeof(float), 7);
-    compute_encoder->setBytes(&beta_, sizeof(float), 8);
+    compute_encoder.set_bytes(alpha_, 7);
+    compute_encoder.set_bytes(beta_, 8);

-    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-    set_vector_bytes(compute_encoder, batch_shape, 10);
-    set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-    set_vector_bytes(compute_encoder, batch_strides_mat, 12);
-    set_vector_bytes(compute_encoder, C_batch_stride, 13);
+    compute_encoder.set_bytes(batch_ndim, 9);
+    compute_encoder.set_vector_bytes(batch_shape, 10);
+    compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+    compute_encoder.set_vector_bytes(batch_strides_mat, 12);
+    compute_encoder.set_vector_bytes(C_batch_stride, 13);

     int bias_stride = c.strides()[c.ndim() - 1];
-    compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
+    compute_encoder.set_bytes(bias_stride, 14);

-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

     d.add_temporaries(std::move(copies), s.index);
     return;
@@ -907,7 +903,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       mn_aligned,
       k_aligned);

-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   int tn = (N + bn - 1) / bn;
   int tm = (M + bm - 1) / bm;
@@ -933,8 +929,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(b, 1);
     compute_encoder.set_output_array(C_split, 2);

-    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.set_bytes(params, 3);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

     // Do accum kernel
     {
@@ -943,25 +939,25 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       auto kernel = get_steel_gemm_splitk_accum_kernel(
           d, kernel_name, C_split, out, true);

-      compute_encoder->setComputePipelineState(kernel);
+      compute_encoder.set_compute_pipeline_state(kernel);

       // Set the arguments for the kernel
       compute_encoder.set_input_array(C_split, 0);
       compute_encoder.set_output_array(out, 1);
-      compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
-      compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
-      compute_encoder->setBytes(&N, sizeof(int), 4);
+      compute_encoder.set_bytes(split_k_partitions, 2);
+      compute_encoder.set_bytes(split_k_partition_stride, 3);
+      compute_encoder.set_bytes(N, 4);
       compute_encoder.set_input_array(c, 5);
-      compute_encoder->setBytes(&ldc, sizeof(int), 6);
-      compute_encoder->setBytes(&fdc, sizeof(int), 7);
-      compute_encoder->setBytes(&alpha_, sizeof(float), 8);
-      compute_encoder->setBytes(&beta_, sizeof(float), 9);
+      compute_encoder.set_bytes(ldc, 6);
+      compute_encoder.set_bytes(fdc, 7);
+      compute_encoder.set_bytes(alpha_, 8);
+      compute_encoder.set_bytes(beta_, 9);

       // Launch enough thread groups for each output
       MTL::Size grid_dims = MTL::Size(N, M, 1);
       MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);

-      compute_encoder.dispatchThreads(grid_dims, group_dims);
+      compute_encoder.dispatch_threads(grid_dims, group_dims);
     }

     d.add_temporaries(std::move(copies), s.index);
@@ -1032,7 +1028,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       wm,
      wn);

-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   int tn = (N + bn - 1) / bn;
   int tm = (M + bm - 1) / bm;
@@ -1083,13 +1079,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(c, 2);
   compute_encoder.set_output_array(out, 3);

-  compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
-  compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
+  compute_encoder.set_bytes(gemm_params, 4);
+  compute_encoder.set_bytes(params, 5);

-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
+  compute_encoder.set_vector_bytes(batch_shape, 6);
+  compute_encoder.set_vector_bytes(batch_strides, 7);

-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

   d.add_temporaries(std::move(copies), s.index);
 }
@@ -1304,7 +1300,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       contiguous_kernel);

   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
   MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1372,18 +1368,18 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(vec, 1);
   compute_encoder.set_output_array(out, 3);

-  compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-  compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-  compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
-  compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-  set_vector_bytes(compute_encoder, batch_shape, 10);
-  set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-  set_vector_bytes(compute_encoder, batch_strides_mat, 12);
+  compute_encoder.set_bytes(in_vector_len, 4);
+  compute_encoder.set_bytes(out_vector_len, 5);
+  compute_encoder.set_bytes(mat_ld, 6);
+  compute_encoder.set_bytes(batch_ndim, 9);
+  compute_encoder.set_vector_bytes(batch_shape, 10);
+  compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+  compute_encoder.set_vector_bytes(batch_strides_mat, 12);

-  set_vector_bytes(compute_encoder, mask_strides, 23);
-  set_vector_bytes(compute_encoder, mask_batch_strides, 24);
+  compute_encoder.set_vector_bytes(mask_strides, 23);
+  compute_encoder.set_vector_bytes(mask_batch_strides, 24);

-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

   d.add_temporaries(std::move(copies), s.index);
   return;
@@ -1423,7 +1419,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       wn,
       mn_aligned,
       k_aligned);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Use problem size to determine threadblock swizzle
   int tn = (N + bn - 1) / bn;
@@ -1486,14 +1482,14 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(b, 1);
   compute_encoder.set_output_array(out, 3);

-  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+  compute_encoder.set_bytes(params, 4);

-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
+  compute_encoder.set_vector_bytes(batch_shape, 6);
+  compute_encoder.set_vector_bytes(batch_strides, 7);

-  set_vector_bytes(compute_encoder, mask_strides, 13);
+  compute_encoder.set_vector_bytes(mask_strides, 13);

-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

   d.add_temporaries(std::move(copies), s.index);
 }
@@ -1687,7 +1683,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);

     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
     MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1697,28 +1693,28 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(vec, 1);
     compute_encoder.set_output_array(out, 3);

-    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+    compute_encoder.set_bytes(in_vector_len, 4);
+    compute_encoder.set_bytes(out_vector_len, 5);
+    compute_encoder.set_bytes(mat_ld, 6);

-    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-    set_vector_bytes(compute_encoder, batch_shape, 10);
-    set_vector_bytes(compute_encoder, batch_strides, 11);
+    compute_encoder.set_bytes(batch_ndim, 9);
+    compute_encoder.set_vector_bytes(batch_shape, 10);
+    compute_encoder.set_vector_bytes(batch_strides, 11);

     int batch_ndim_vec = batch_shape_vec.size();
-    compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
-    set_vector_bytes(compute_encoder, batch_shape_vec, 13);
-    set_vector_bytes(compute_encoder, batch_strides_vec, 14);
+    compute_encoder.set_bytes(batch_ndim_vec, 12);
+    compute_encoder.set_vector_bytes(batch_shape_vec, 13);
+    compute_encoder.set_vector_bytes(batch_strides_vec, 14);

     int batch_ndim_mat = batch_shape_mat.size();
-    compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
-    set_vector_bytes(compute_encoder, batch_shape_mat, 16);
-    set_vector_bytes(compute_encoder, batch_strides_mat, 17);
+    compute_encoder.set_bytes(batch_ndim_mat, 15);
+    compute_encoder.set_vector_bytes(batch_shape_mat, 16);
+    compute_encoder.set_vector_bytes(batch_strides_mat, 17);

     compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
     compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));

-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

     d.add_temporaries(std::move(copies), s.index);
     return;
@@ -1788,7 +1784,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       wm,
       wn);

-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Use problem size to determine threadblock swizzle
   int tn = (N + bn - 1) / bn;
@@ -1827,10 +1823,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(b, 1);
   compute_encoder.set_output_array(out, 3);

-  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+  compute_encoder.set_bytes(params, 4);

-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
+  compute_encoder.set_vector_bytes(batch_shape, 6);
+  compute_encoder.set_vector_bytes(batch_strides, 7);

   compute_encoder.set_input_array(lhs_indices, 10);
   compute_encoder.set_input_array(rhs_indices, 11);
@@ -1845,11 +1841,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {

   operand_batch_ndim.push_back(0);

-  set_vector_bytes(compute_encoder, operand_shape, 13);
-  set_vector_bytes(compute_encoder, operand_strides, 14);
-  set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
+  compute_encoder.set_vector_bytes(operand_shape, 13);
+  compute_encoder.set_vector_bytes(operand_strides, 14);
+  compute_encoder.set_vector_bytes(operand_batch_ndim, 15);

-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

   d.add_temporaries(std::move(copies), s.index);
 }
@ -78,18 +78,15 @@ void RMSNorm::eval_gpu(
}

uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&eps_, sizeof(float), 3);
compute_encoder->setBytes(&axis_size, sizeof(int), 4);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
compute_encoder->setThreadgroupMemoryLength(
16 * 8, 0); // minimum of 16 bytes
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(eps_, 3);
compute_encoder.set_bytes(axis_size, 4);
compute_encoder.set_bytes(w_stride, 5);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

d.add_temporaries(std::move(copies), s.index);
@ -183,16 +180,16 @@ void RMSNormVJP::eval_gpu(
}

uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4);
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(eps_, 5);
compute_encoder.set_bytes(axis_size, 6);
compute_encoder.set_bytes(w_stride, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

ReductionPlan plan(
@ -273,17 +270,17 @@ void LayerNorm::eval_gpu(

uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(b, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&eps_, sizeof(float), 4);
compute_encoder->setBytes(&axis_size, sizeof(int), 5);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(eps_, 4);
compute_encoder.set_bytes(axis_size, 5);
compute_encoder.set_bytes(w_stride, 6);
compute_encoder.set_bytes(b_stride, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

d.add_temporaries(std::move(copies), s.index);
@ -395,16 +392,16 @@ void LayerNormVJP::eval_gpu(
}

uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4);
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(eps_, 5);
compute_encoder.set_bytes(axis_size, 6);
compute_encoder.set_bytes(w_stride, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

if (gw.ndim() == 1 && gw.size() == axis_size) {
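The dispatch calls follow the same renaming. Continuing the EncoderSketch above, the forwarding members would plausibly look like this (again a sketch, not the real implementation):

// snake_case members forwarding to metal-cpp's camelCase dispatch calls.
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims) {
  raw->dispatchThreads(grid_dims, group_dims);
}
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims) {
  raw->dispatchThreadgroups(grid_dims, group_dims);
}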
@ -17,10 +17,10 @@
namespace mlx::core {

template <typename T>
void arange_set_scalars(T start, T next, CommandEncoder& enc) {
enc->setBytes(&start, sizeof(T), 0);
void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(start, 0);
T step = next - start;
enc->setBytes(&step, sizeof(T), 1);
enc.set_bytes(step, 1);
}

void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -37,7 +37,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

switch (out.dtype()) {
case bool_: // unsupported
@ -80,7 +80,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
}

compute_encoder.set_output_array(out, 2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -129,25 +129,25 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
if (ndim == 0) {
// Pass place holders so metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 2);
compute_encoder->setBytes(&stride_, sizeof(size_t), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
compute_encoder.set_bytes(shape_, 2);
compute_encoder.set_bytes(stride_, 3);
compute_encoder.set_bytes(stride_, 4);
} else {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
compute_encoder.set_vector_bytes(shape, 2);
compute_encoder.set_vector_bytes(in_strides, 3);
compute_encoder.set_vector_bytes(out_strides, 4);
}
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(ndim, 5);
compute_encoder.set_bytes(axis_stride, 6);
compute_encoder.set_bytes(axis_size, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

@ -275,22 +275,20 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(keys, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&odd, sizeof(bool), 2);
compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
compute_encoder.set_bytes(odd, 2);
compute_encoder.set_bytes(bytes_per_key, 3);

if (!keys.flags().row_contiguous) {
int ndim = keys.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
compute_encoder->setBytes(
keys.shape().data(), keys.ndim() * sizeof(int), 5);
compute_encoder->setBytes(
keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
compute_encoder.set_bytes(ndim, 4);
compute_encoder.set_vector_bytes(keys.shape(), 5);
compute_encoder.set_vector_bytes(keys.strides(), 6);
}

compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -101,31 +101,31 @@ void launch_qmm(
auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.set_bytes(D, 5);
compute_encoder.set_bytes(O, 6);

int offset = 7;
if (matrix) {
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder.set_bytes(B, 7);
offset += 1;
}

if (batched || gather) {
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset);
set_vector_bytes(compute_encoder, x_shape, offset + 1);
set_vector_bytes(compute_encoder, x_strides, offset + 2);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3);
set_vector_bytes(compute_encoder, w_shape, offset + 4);
set_vector_bytes(compute_encoder, w_strides, offset + 5);
set_vector_bytes(compute_encoder, s_strides, offset + 6);
set_vector_bytes(compute_encoder, b_strides, offset + 7);
compute_encoder.set_bytes(x_batch_ndims, offset);
compute_encoder.set_vector_bytes(x_shape, offset + 1);
compute_encoder.set_vector_bytes(x_strides, offset + 2);
compute_encoder.set_bytes(w_batch_ndims, offset + 3);
compute_encoder.set_vector_bytes(w_shape, offset + 4);
compute_encoder.set_vector_bytes(w_strides, offset + 5);
compute_encoder.set_vector_bytes(s_strides, offset + 6);
compute_encoder.set_vector_bytes(b_strides, offset + 7);
}
if (gather) {
auto& lhs_indices = inputs[4];
@ -137,15 +137,15 @@ void launch_qmm(
auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides();

compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8);
set_vector_bytes(compute_encoder, batch_shape, offset + 9);
compute_encoder.set_bytes(batch_ndims, offset + 8);
compute_encoder.set_vector_bytes(batch_shape, offset + 9);
compute_encoder.set_input_array(lhs_indices, offset + 10);
compute_encoder.set_input_array(rhs_indices, offset + 11);
set_vector_bytes(compute_encoder, lhs_strides, offset + 12);
set_vector_bytes(compute_encoder, rhs_strides, offset + 13);
compute_encoder.set_vector_bytes(lhs_strides, offset + 12);
compute_encoder.set_vector_bytes(rhs_strides, offset + 13);
}

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}

@ -236,27 +236,27 @@ void qvm_split_k(
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder->setBytes(&split_D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.set_bytes(split_D, 5);
compute_encoder.set_bytes(O, 6);

compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
set_vector_bytes(compute_encoder, x_shape, 8);
set_vector_bytes(compute_encoder, x_strides, 9);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, w_shape, 11);
set_vector_bytes(compute_encoder, w_strides, 12);
set_vector_bytes(compute_encoder, s_strides, 13);
set_vector_bytes(compute_encoder, b_strides, 14);
compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
compute_encoder.set_bytes(x_batch_ndims, 7);
compute_encoder.set_vector_bytes(x_shape, 8);
compute_encoder.set_vector_bytes(x_strides, 9);
compute_encoder.set_bytes(w_batch_ndims, 10);
compute_encoder.set_vector_bytes(w_shape, 11);
compute_encoder.set_vector_bytes(w_strides, 12);
compute_encoder.set_vector_bytes(s_strides, 13);
compute_encoder.set_vector_bytes(b_strides, 14);
compute_encoder.set_bytes(final_block_size, 15);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);

int axis = intermediate.ndim() - 3;
@ -447,7 +447,7 @@ void fast::AffineQuantize::eval_gpu(
auto template_def = get_template_definition(
kname.str(), kernel_func, type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4;
@ -471,7 +471,7 @@ void fast::AffineQuantize::eval_gpu(
}
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
}
@ -67,17 +67,14 @@ struct RowReduceArgs {
strides.push_back(0);
}

compute_encoder->setBytes(&row_size, sizeof(size_t), 2);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->setBytes(
reduce_shape.data(), reduce_shape.size() * sizeof(int), 7);
compute_encoder->setBytes(
reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9);
compute_encoder.set_bytes(row_size, 2);
compute_encoder.set_bytes(non_row_reductions, 3);
compute_encoder.set_vector_bytes(shape, 4);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_bytes(ndim, 6);
compute_encoder.set_vector_bytes(reduce_shape, 7);
compute_encoder.set_vector_bytes(reduce_strides, 8);
compute_encoder.set_bytes(reduce_ndim, 9);

if (reduce_ndim == 0) {
reduce_shape.pop_back();
@ -166,18 +163,15 @@ struct ColReduceArgs {
strides.push_back(0);
}

compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->setBytes(
reduce_shape.data(), reduce_shape.size() * sizeof(int), 7);
compute_encoder->setBytes(
reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9);
compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 10);
compute_encoder.set_bytes(reduction_size, 2);
compute_encoder.set_bytes(reduction_stride, 3);
compute_encoder.set_vector_bytes(shape, 4);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_bytes(ndim, 6);
compute_encoder.set_vector_bytes(reduce_shape, 7);
compute_encoder.set_vector_bytes(reduce_strides, 8);
compute_encoder.set_bytes(reduce_ndim, 9);
compute_encoder.set_bytes(non_col_reductions, 10);

if (reduce_ndim == 0) {
reduce_shape.pop_back();
@ -256,9 +250,9 @@ void init_reduce(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_output_array(out, 0);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void all_reduce_dispatch(
@ -273,7 +267,7 @@ void all_reduce_dispatch(
const std::string func_name = "all_reduce";
kname << func_name << "_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

size_t in_size = in.size();

@ -285,9 +279,9 @@ void all_reduce_dispatch(

compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->setBytes(&in_size, sizeof(size_t), 3);
compute_encoder.dispatchThreads(grid_dims, grid_dims);
compute_encoder.set_bytes(in_size, 2);
compute_encoder.set_bytes(in_size, 3);
compute_encoder.dispatch_threads(grid_dims, grid_dims);
}

// We need multiple threadgroups so we'll do it in 2 passes.
@ -319,24 +313,24 @@ void all_reduce_dispatch(
MTL::Size group_dims(threadgroup_size, 1, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->setBytes(&row_size, sizeof(size_t), 3);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(in_size, 2);
compute_encoder.set_bytes(row_size, 3);
compute_encoder.dispatch_threads(grid_dims, group_dims);

// 2nd pass
std::ostringstream kname_2nd_pass;
kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass = get_reduce_kernel(
d, kname_2nd_pass.str(), func_name, op_name, intermediate, out);
compute_encoder->setComputePipelineState(kernel_2nd_pass);
compute_encoder.set_compute_pipeline_state(kernel_2nd_pass);
size_t intermediate_size = n_rows;
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 3);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(intermediate_size, 2);
compute_encoder.set_bytes(intermediate_size, 3);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

@ -355,7 +349,7 @@ void row_reduce_small(
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Figure out the grid dims
MTL::Size grid_dims;
@ -375,7 +369,7 @@ void row_reduce_small(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void row_reduce_simple(
@ -391,7 +385,7 @@ void row_reduce_simple(
const std::string func_name = "row_reduce_simple";
kname << func_name << "_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Figure out the grid dims
size_t row_size = args.row_size;
@ -410,9 +404,9 @@ void row_reduce_simple(
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&row_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(row_size, 2);
compute_encoder.set_bytes(out_size, 3);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void row_reduce_looped(
@ -430,7 +424,7 @@ void row_reduce_looped(
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Figure out the grid
auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
@ -443,7 +437,7 @@ void row_reduce_looped(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void row_reduce_general_dispatch(
@ -495,7 +489,7 @@ void strided_reduce_small(
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

const int n_reads = 4;
size_t reduction_stride_blocks =
@ -517,7 +511,7 @@ void strided_reduce_small(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void strided_reduce_longcolumn(
@ -568,14 +562,14 @@ void strided_reduce_longcolumn(
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.set_bytes(out_size, 11);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
@ -599,12 +593,12 @@ void strided_reduce_longcolumn(
1,
32,
32);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void strided_reduce_looped(
@ -639,13 +633,13 @@ void strided_reduce_looped(
<< op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void strided_reduce_2pass(
@ -692,14 +686,14 @@ void strided_reduce_2pass(
<< op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(out_size, 11);
compute_encoder.dispatch_threads(grid_dims, group_dims);

// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
@ -721,12 +715,12 @@ void strided_reduce_2pass(
1,
32,
32);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void strided_reduce_general_dispatch(
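Note how the reduce paths funnel most argument uploads through RowReduceArgs::encode and ColReduceArgs::encode instead of repeating the set_bytes calls at every dispatch site. A compressed sketch of that shape, with fields abridged from the hunks above (the real structs carry more state):

#include <vector>

// Sketch: bundle kernel arguments so each reduce variant encodes them
// with one call. Buffer indices mirror the hunks above; EncoderSketch
// is the wrapper sketched earlier, and set_vector_bytes is its vector
// counterpart, sketched at the utils.h hunk further down.
struct RowReduceArgsSketch {
  size_t row_size;
  size_t non_row_reductions;
  std::vector<int> shape;
  std::vector<size_t> strides;
  int ndim;

  void encode(EncoderSketch& enc) {
    enc.set_bytes(row_size, 2);
    enc.set_bytes(non_row_reductions, 3);
    enc.set_vector_bytes(shape, 4);
    enc.set_vector_bytes(strides, 5);
    enc.set_bytes(ndim, 6);
  }
};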
@ -75,24 +75,24 @@ void RoPE::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);

float base = std::log2(base_);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&offset_, sizeof(int), 2);
compute_encoder->setBytes(&scale_, sizeof(float), 3);
compute_encoder.set_bytes(offset_, 2);
compute_encoder.set_bytes(scale_, 3);

size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder->setBytes(out_strides, sizeof(size_t), 4);
compute_encoder.set_bytes(out_strides, 1, 4);
uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
} else {
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder.set_bytes(n_batch, 6);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
@ -104,11 +104,11 @@ void RoPE::eval_gpu(
auto& freqs = inputs[1];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
compute_encoder.set_bytes(freq_stride, 11);
} else {
compute_encoder->setBytes(&base, sizeof(float), 10);
compute_encoder.set_bytes(base, 10);
}
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

} // namespace mlx::core::fast
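The RoPE hunk also uses a count-taking form, set_bytes(strides, 3, 4), for plain C arrays whose element count cannot be deduced from the type alone. A plausible overload for the sketch encoder, with the count made explicit:

// Sketch only: for a raw pointer or array, the element count n must be
// passed, since sizeof(T) covers a single element.
template <typename T>
void set_bytes(const T* v, int n, int idx) {
  raw->setBytes(v, n * sizeof(T), idx);
}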
@ -59,7 +59,7 @@ void sdpa_full_self_attention_metal(

auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_self_attention.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

uint hidden_dim = q.shape(-1);
uint qseq = q.shape(-2);
@ -129,17 +129,14 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&params, sizeof(MLXFastAttentionParams), 4);
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);

compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder.set_bytes(params, 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);

MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
MTL::Size group_dims = MTL::Size(32, wm, wn);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void sdpa_vector(
@ -170,21 +167,21 @@ void sdpa_vector(
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
compute_encoder->setBytes(&N, sizeof(int), 5);
compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
compute_encoder->setBytes(&scale, sizeof(float), 8);
compute_encoder.set_bytes(gqa_factor, 4);
compute_encoder.set_bytes(N, 5);
compute_encoder.set_bytes(k_stride, 6);
compute_encoder.set_bytes(v_stride, 7);
compute_encoder.set_bytes(scale, 8);

// Launch
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

} // namespace
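Worth noticing: set_input_array and set_output_array appear unchanged on both sides of every hunk, so they were already members of the wrapper. What "fully wrap" adds is the same treatment for the remaining raw calls, pipeline-state binding included. As a sketch:

// Sketch only: forwarding the pipeline-state binding.
void set_compute_pipeline_state(MTL::ComputePipelineState* state) {
  raw->setComputePipelineState(state);
}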
@ -68,12 +68,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {

if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_);
compute_encoder->setBytes(&size, sizeof(size_t), 2);
compute_encoder.set_bytes(size, 2);

// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
@ -95,10 +95,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
@ -107,9 +107,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int bm = 32;
int bn = 32;
size_t stride_blocks = (stride + bn - 1) / bn;
compute_encoder->setBytes(&size, sizeof(size_t), 2);
compute_encoder->setBytes(&stride, sizeof(size_t), 3);
compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4);
compute_encoder.set_bytes(size, 2);
compute_encoder.set_bytes(stride, 3);
compute_encoder.set_bytes(stride_blocks, 4);

// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
@ -125,7 +125,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

d.add_temporaries(std::move(copies), s.index);
@ -81,12 +81,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
group_dims = MTL::Size(threadgroup_size, 1, 1);
}

compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

d.add_temporaries(std::move(copies), s.index);
@ -68,29 +68,29 @@ void single_block_sort(

// Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

// Set inputs
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
compute_encoder.set_bytes(size_sorted_axis, 2);
compute_encoder.set_bytes(in_stride_sorted_axis, 3);
compute_encoder.set_bytes(out_stride_sorted_axis, 4);

if (contiguous) {
compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
compute_encoder.set_bytes(in_stride_segment_axis, 5);
compute_encoder.set_bytes(out_stride_segment_axis, 6);
} else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
compute_encoder.set_bytes(nc_dim, 5);
compute_encoder.set_vector_bytes(nc_shape, 6);
compute_encoder.set_vector_bytes(in_nc_str, 7);
compute_encoder.set_vector_bytes(out_nc_str, 8);
}

MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void multi_block_sort(
@ -152,22 +152,21 @@ void multi_block_sort(
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(dev_vals_0, 1);
compute_encoder.set_output_array(dev_idxs_0, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(
nc_shape.data(), nc_shape.size() * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
compute_encoder.set_bytes(size_sorted_axis, 3);
compute_encoder.set_bytes(stride_sorted_axis, 4);
compute_encoder.set_bytes(nc_dim, 5);
compute_encoder.set_vector_bytes(nc_shape, 6);
compute_encoder.set_vector_bytes(nc_str, 7);

MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Do merges
@ -194,19 +193,19 @@ void multi_block_sort(

auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_output_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
compute_encoder.set_bytes(size_sorted_axis, 3);
compute_encoder.set_bytes(merge_tiles, 4);
compute_encoder.set_bytes(n_blocks, 5);

MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Do merge
@ -217,21 +216,21 @@ void multi_block_sort(

auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder.set_output_array(dev_vals_out, 3);
compute_encoder.set_output_array(dev_idxs_out, 4);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
compute_encoder.set_bytes(size_sorted_axis, 5);
compute_encoder.set_bytes(merge_tiles, 6);
compute_encoder.set_bytes(n_blocks, 7);

MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}
@ -63,7 +63,7 @@ void ternary_op_gpu_inplace(
auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);

auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
bool donate_c = c.data_shared_ptr() == nullptr;
@ -80,18 +80,18 @@ void ternary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1);

if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
compute_encoder.set_vector_bytes(shape, 4);
compute_encoder.set_vector_bytes(strides_a, 5);
compute_encoder.set_vector_bytes(strides_b, 6);
compute_encoder.set_vector_bytes(strides_c, 7);

compute_encoder->setBytes(&ndim, sizeof(int), 8);
compute_encoder.set_bytes(ndim, 8);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
compute_encoder.set_vector_bytes(strides_a, 4);
compute_encoder.set_vector_bytes(strides_b, 5);
compute_encoder.set_vector_bytes(strides_c, 6);
}

if (thread_group_size != 1024) {
@ -99,7 +99,7 @@ void ternary_op_gpu_inplace(
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
@ -109,7 +109,7 @@ void ternary_op_gpu_inplace(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
@ -49,7 +49,7 @@ void unary_op_gpu_inplace(

auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
@ -58,16 +58,16 @@ void unary_op_gpu_inplace(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(strides.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(&ndim, sizeof(int), 4);
compute_encoder.set_vector_bytes(shape, 2);
compute_encoder.set_vector_bytes(strides, 3);
compute_encoder.set_bytes(ndim, 4);
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::unary] Must use 1024 sized block");
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@ -75,7 +75,7 @@ void unary_op_gpu_inplace(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
@ -8,23 +8,6 @@

namespace mlx::core {

using metal::CommandEncoder;

template <typename T>
inline void set_vector_bytes(
CommandEncoder& enc,
const std::vector<T>& vec,
size_t nelems,
int idx) {
enc->setBytes(vec.data(), nelems * sizeof(T), idx);
}

template <typename T>
inline void
set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
return set_vector_bytes(enc, vec, vec.size(), idx);
}

std::string type_to_name(const array& a);

// Compute the thread block dimensions which fit the given
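The free helpers deleted above pin down exactly what the member set_vector_bytes must do, and moving them onto the encoder is what lets every call site drop the extra encoder argument. A sketch of the member form, derived from the deleted code (the actual signatures in MLX may differ):

// Derived from the removed free functions: upload nelems elements of a
// std::vector (or all of them) into argument slot idx.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
  raw->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
  set_vector_bytes(vec, vec.size(), idx);
}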