Fully wrap the command encoder (#1572)

* fully wrap the command encoder

* use consistent style + fix extensions
Awni Hannun 2024-11-08 11:50:21 -08:00 committed by GitHub
parent 59247c2b62
commit 9f0d5c12fc
27 changed files with 469 additions and 484 deletions

View File

@@ -494,7 +494,7 @@ below.
   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Kernel parameters are registered with buffer indices corresponding to
   // those in the kernel declaration at axpby.metal
@@ -509,14 +509,14 @@ below.
   compute_encoder.set_output_array(out, 2);
 
   // Encode alpha and beta
-  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-  compute_encoder->setBytes(&beta_, sizeof(float), 4);
+  compute_encoder.set_bytes(alpha_, 3);
+  compute_encoder.set_bytes(beta_, 4);
 
   // Encode shape, strides and ndim
-  compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-  compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-  compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-  compute_encoder->setBytes(&ndim, sizeof(int), 8);
+  compute_encoder.set_vector_bytes(x.shape(), 5);
+  compute_encoder.set_vector_bytes(x.strides(), 6);
+  compute_encoder.set_bytes(y.strides(), 7);
+  compute_encoder.set_bytes(ndim, 8);
 
   // We launch 1 thread for each input and make sure that the number of
   // threads in any given threadgroup is not higher than the max allowed
@@ -530,7 +530,7 @@ below.
   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
 We can now call the :meth:`axpby` operation on both the CPU and the GPU!

View File

@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Kernel parameters are registered with buffer indices corresponding to
   // those in the kernel declaration at axpby.metal
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
   compute_encoder.set_output_array(out, 2);
 
   // Encode alpha and beta
-  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-  compute_encoder->setBytes(&beta_, sizeof(float), 4);
+  compute_encoder.set_bytes(alpha_, 3);
+  compute_encoder.set_bytes(beta_, 4);
 
   // Encode shape, strides and ndim if needed
   if (!contiguous_kernel) {
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);
   }
 
   // We launch 1 thread for each input and make sure that the number of
@@ -295,7 +295,7 @@ void Axpby::eval_gpu(
   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
 #else // Metal is not available

View File

@@ -242,6 +242,9 @@ void MetalAllocator::clear_cache() {
 void MetalAllocator::free(Buffer buffer) {
   auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
+  if (buf == nullptr) {
+    return;
+  }
   std::unique_lock lk(mutex_);
   residency_set_.erase(buf);
   active_memory_ -= buf->length();
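The added guard turns freeing an empty buffer into a no-op. A one-line sketch of the case it covers (the `Buffer{nullptr}` construction is assumed for illustration, not taken from this diff):

  // Hypothetical call: previously buf->length() below would dereference a
  // null MTL::Buffer*; with the guard it now returns early.
  allocator().free(Buffer{nullptr});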

View File

@@ -92,7 +92,7 @@ void binary_op_gpu_inplace(
       ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
       : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // - If a is donated it goes to the first output
   // - If b is donated it goes to the first output if a was not donated
@@ -117,19 +117,15 @@ void binary_op_gpu_inplace(
     size_t rest = out.size() / (dim0 * dim1);
     if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
-      compute_encoder->setBytes(
-          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
-      compute_encoder->setBytes(
-          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
-      compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
+      compute_encoder.set_vector_bytes(shape, arg_idx++);
+      compute_encoder.set_vector_bytes(strides_a, arg_idx++);
+      compute_encoder.set_vector_bytes(strides_b, arg_idx++);
+      compute_encoder.set_bytes<int>(ndim, arg_idx++);
       dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     } else {
       // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(
-          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
-      compute_encoder->setBytes(
-          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder.set_vector_bytes(strides_a, arg_idx++);
+      compute_encoder.set_vector_bytes(strides_b, arg_idx++);
     }
 
     if (thread_group_size != 1024) {
@@ -137,7 +133,7 @@ void binary_op_gpu_inplace(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
@@ -147,7 +143,7 @@ void binary_op_gpu_inplace(
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
                                  : MTL::Size(nthreads, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }

View File

@@ -373,7 +373,7 @@ void Compiled::eval_gpu(
   }
   auto kernel = d.get_kernel(kernel_name, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Put the inputs in
   int cnt = 0;
@@ -394,8 +394,7 @@ void Compiled::eval_gpu(
     }
   }
   if (!in_strides.empty()) {
-    compute_encoder->setBytes(
-        in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
+    compute_encoder.set_vector_bytes(in_strides, cnt++);
   }
 
   compiled_allocate_outputs(
@@ -408,14 +407,13 @@ void Compiled::eval_gpu(
   // Put the output shape and strides in
   if (!contiguous) {
-    compute_encoder->setBytes(
-        strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
-    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
+    compute_encoder.set_vector_bytes(strides[0], cnt++);
+    compute_encoder.set_vector_bytes(shape, cnt++);
   }
 
   // Put the number of dims in if it is dynamic
   if (dynamic) {
-    compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
+    compute_encoder.set_bytes(ndim, cnt++);
   }
 
   // Launch the kernel
@@ -427,7 +425,7 @@ void Compiled::eval_gpu(
     MTL::Size grid_dims = use_2d
         ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
        : MTL::Size(nthreads, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -445,7 +443,7 @@ void Compiled::eval_gpu(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }

View File

@@ -44,12 +44,12 @@ void explicit_gemm_conv_ND_gpu(
   kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(in_unfolded, 1);
 
-  compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
+  compute_encoder.set_bytes(conv_params, 2);
 
   // Launch unfolding kernel
   int tgp_x = std::min(conv_params.C, 64);
@@ -60,7 +60,7 @@ void explicit_gemm_conv_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
 
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 
   // Reshape weight
   std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -122,12 +122,12 @@ void explicit_gemm_conv_group_ND_gpu(
         << N;
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(in_unfolded, 1);
 
-  compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
+  compute_encoder.set_bytes(conv_params, 2);
 
   // Launch unfolding kernel
   int tgp_x = std::min(conv_params.C, 64);
@@ -138,7 +138,7 @@ void explicit_gemm_conv_group_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
 
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 
   // Transpose kernel weights so that we can slice them by contiguous chunks
   // of channel groups.
@@ -237,7 +237,7 @@ void slow_conv_2D_gpu(
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
@@ -252,8 +252,8 @@ void slow_conv_2D_gpu(
   compute_encoder.set_input_array(wt, 1);
   compute_encoder.set_output_array(out, 2);
 
-  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.set_bytes(conv_params, 3);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
 void implicit_gemm_conv_2D_gpu(
@@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
       wn,
       n_channel_specialization,
      small_filter);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Deduce grid launch dimensions
   int tile = 1 << swizzle_log;
@@ -368,11 +368,11 @@ void implicit_gemm_conv_2D_gpu(
   compute_encoder.set_output_array(out, 2);
 
   // Encode params
-  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
+  compute_encoder.set_bytes(conv_params, 3);
+  compute_encoder.set_bytes(gemm_params, 4);
 
   // Launch kernel
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
 void implicit_gemm_conv_2D_general_gpu(
@@ -506,7 +506,7 @@ void implicit_gemm_conv_2D_general_gpu(
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel =
       get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Deduce grid launch dimensions
   int tile = 1 << swizzle_log;
@@ -523,17 +523,15 @@ void implicit_gemm_conv_2D_general_gpu(
   compute_encoder.set_output_array(out, 2);
 
   // Encode params
-  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
-  compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
-  compute_encoder->setBytes(
-      base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
-  compute_encoder->setBytes(
-      base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
+  compute_encoder.set_bytes(conv_params, 3);
+  compute_encoder.set_bytes(gemm_params, 4);
+  compute_encoder.set_bytes(jump_params, 5);
+  compute_encoder.set_vector_bytes(base_h, 6);
+  compute_encoder.set_vector_bytes(base_w, 7);
 
   // Launch kernel
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
 void winograd_conv_2D_gpu(
@@ -622,18 +620,18 @@ void winograd_conv_2D_gpu(
          << bc;
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(wt, 0);
     compute_encoder.set_output_array(filt_wg, 1);
 
-    compute_encoder->setBytes(&C_c, sizeof(int), 2);
-    compute_encoder->setBytes(&O_c, sizeof(int), 3);
+    compute_encoder.set_bytes(C_c, 2);
+    compute_encoder.set_bytes(O_c, 3);
 
     MTL::Size group_dims = MTL::Size(32, bo, 1);
     MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
 
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
   }
 
   // Do input transform
@@ -650,18 +648,17 @@ void winograd_conv_2D_gpu(
          << bc;
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(in_padded, 0);
     compute_encoder.set_output_array(inp_wg, 1);
 
-    compute_encoder->setBytes(
-        &conv_params_updated, sizeof(MLXConvParams<2>), 2);
+    compute_encoder.set_bytes(conv_params_updated, 2);
 
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
 
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
   }
 
   // Do batched gemm
@@ -698,18 +695,17 @@ void winograd_conv_2D_gpu(
          << bc;
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(out_wg, 0);
     compute_encoder.set_output_array(out, 1);
 
-    compute_encoder->setBytes(
-        &conv_params_updated, sizeof(MLXConvParams<2>), 2);
+    compute_encoder.set_bytes(conv_params_updated, 2);
 
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
 
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
   }
 }

View File

@@ -111,7 +111,7 @@ void copy_gpu_inplace(
   auto kernel = get_copy_kernel(d, kernel_name, in, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   bool donate_in = in.data_shared_ptr() == nullptr;
   inp_offset *= size_of(in.dtype());
@@ -125,11 +125,11 @@ void copy_gpu_inplace(
     std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
     std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
     if (ndim > 3) {
-      set_vector_bytes(compute_encoder, shape, ndim, 2);
+      compute_encoder.set_vector_bytes(shape, ndim, 2);
     }
-    set_vector_bytes(compute_encoder, strides_in, ndim, 3);
+    compute_encoder.set_vector_bytes(strides_in, ndim, 3);
     if (ctype == CopyType::GeneralGeneral) {
-      set_vector_bytes(compute_encoder, strides_out, ndim, 4);
+      compute_encoder.set_vector_bytes(strides_out, ndim, 4);
     }
 
     int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -141,7 +141,7 @@ void copy_gpu_inplace(
     int rest = data_size / (dim0 * dim1);
 
     if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 5);
+      compute_encoder.set_bytes(ndim, 5);
       dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     }
@@ -152,7 +152,7 @@ void copy_gpu_inplace(
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
     if (thread_group_size > nthreads) {
@@ -161,7 +161,7 @@ void copy_gpu_inplace(
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
                                  : MTL::Size(nthreads, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
@@ -199,7 +199,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
       type_to_name(val) + type_to_name(out);
   auto kernel = get_copy_kernel(d, kernel_name, val, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(val, 0);
   compute_encoder.set_output_array(out, 1);
@@ -212,7 +212,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
   MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
                                : MTL::Size(nthreads, 1, 1);
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core

View File

@@ -43,7 +43,7 @@ void CustomKernel::eval_gpu(
       d.get_library(lib_name, [this] { return metal::utils() + source_; });
   auto kernel = d.get_kernel(name_, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
   int index = 0;
   for (int i = 0; i < checked_inputs.size(); i++) {
     const array& in = checked_inputs[i];
@@ -53,15 +53,15 @@ void CustomKernel::eval_gpu(
     if (in.ndim() > 0) {
       int ndim = in.ndim();
       if (shape_info.shape) {
-        set_vector_bytes(compute_encoder, in.shape(), ndim, index);
+        compute_encoder.set_vector_bytes(in.shape(), ndim, index);
         index++;
       }
       if (shape_info.strides) {
-        set_vector_bytes(compute_encoder, in.strides(), ndim, index);
+        compute_encoder.set_vector_bytes(in.strides(), ndim, index);
         index++;
      }
       if (shape_info.ndim) {
-        compute_encoder->setBytes(&ndim, sizeof(int), index);
+        compute_encoder.set_bytes(ndim, index);
         index++;
       }
     }
@@ -75,7 +75,7 @@ void CustomKernel::eval_gpu(
   MTL::Size group_dims = MTL::Size(tx, ty, tz);
   const auto [gx, gy, gz] = grid_;
   MTL::Size grid_dims = MTL::Size(gx, gy, gz);
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 
   d.add_temporaries(std::move(copies), s.index);
 }

View File

@@ -171,14 +171,14 @@ void CommandEncoder::maybeInsertBarrier() {
   next_outputs_.clear();
 }
 
-void CommandEncoder::dispatchThreadgroups(
+void CommandEncoder::dispatch_threadgroups(
     MTL::Size grid_dims,
     MTL::Size group_dims) {
   maybeInsertBarrier();
   enc_->dispatchThreadgroups(grid_dims, group_dims);
 }
 
-void CommandEncoder::dispatchThreads(
+void CommandEncoder::dispatch_threads(
     MTL::Size grid_dims,
     MTL::Size group_dims) {
   maybeInsertBarrier();
@@ -298,7 +298,7 @@ void Device::end_encoding(int index) {
       if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
         // If we've already waited on a fence, don't wait on it again.
         if (waiting_on.find(it->second) == waiting_on.end()) {
-          enc->waitForFence(it->second->fence);
+          enc.wait_for_fence(it->second->fence);
           waiting_on.insert(it->second);
         }
       }
@@ -307,7 +307,7 @@ void Device::end_encoding(int index) {
       stream.outputs[out] = stream.fence;
     }
   }
-  enc->updateFence(stream.fence->fence);
+  enc.update_fence(stream.fence->fence);
   stream.buffer->addCompletedHandler(
       [&stream,
        waiting_on = std::move(waiting_on),

View File

@@ -58,16 +58,43 @@ struct CommandEncoder {
     CommandEncoder& enc;
   };
 
-  MTL::ComputeCommandEncoder* operator->() {
-    return enc_;
-  }
-
   void set_input_array(const array& a, int idx, int64_t offset = 0);
   void set_output_array(array& a, int idx, int64_t offset = 0);
-  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
-  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+  void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+  void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
   void maybeInsertBarrier();
 
+  void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
+    enc_->setComputePipelineState(kernel);
+  }
+
+  void wait_for_fence(MTL::Fence* fence) {
+    enc_->waitForFence(fence);
+  }
+
+  void update_fence(MTL::Fence* fence) {
+    enc_->updateFence(fence);
+  }
+
+  template <typename T>
+  void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
+    enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
+  }
+
+  template <typename T>
+  void set_vector_bytes(const std::vector<T>& vec, int idx) {
+    return set_vector_bytes(vec, vec.size(), idx);
+  }
+
+  template <typename T>
+  void set_bytes(const T* v, int n, int idx) {
+    return enc_->setBytes(v, n * sizeof(T), idx);
+  }
+
+  template <typename T>
+  void set_bytes(const T& v, int idx) {
+    return enc_->setBytes(&v, sizeof(T), idx);
+  }
+
   ConcurrentContext start_concurrent() {
     return ConcurrentContext(*this);
   }
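Taken together, the methods above now cover everything call sites need, so the raw `operator->` escape hatch can be removed. A minimal sketch of the resulting encoding pattern in an op's `eval_gpu` (the names `d`, `s`, `kernel`, `in`, `out`, `nthreads`, and the buffer indices are illustrative placeholders, not from this diff):

  // Hypothetical encoding sequence written against the wrapper above.
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel); // was enc->setComputePipelineState(...)
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  int ndim = in.ndim();
  compute_encoder.set_bytes(ndim, 2);                 // scalar: byte count inferred as sizeof(int)
  compute_encoder.set_vector_bytes(in.strides(), 3);  // vector: vec.size() * sizeof(T)
  MTL::Size group_dims = MTL::Size(256, 1, 1);
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  compute_encoder.dispatch_threads(grid_dims, group_dims);

Because `set_bytes` and `set_vector_bytes` infer byte counts from their argument types, the manual `sizeof` arithmetic at the call sites goes away, and since `dispatch_threads`/`dispatch_threadgroups` are now the only way to launch, every dispatch is guaranteed to pass through `maybeInsertBarrier()`.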

View File

@@ -699,7 +699,7 @@ void fft_op(
   auto kernel =
       get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
 
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
   compute_encoder.set_input_array(in_contiguous, 0);
   compute_encoder.set_output_array(out, 1);
@@ -711,9 +711,9 @@ void fft_op(
     compute_encoder.set_input_array(w_q, 2); // w_q
     compute_encoder.set_input_array(w_k, 3); // w_k
-    compute_encoder->setBytes(&n, sizeof(int), 4);
-    compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
-    compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
+    compute_encoder.set_bytes(n, 4);
+    compute_encoder.set_bytes(plan.bluestein_n, 5);
+    compute_encoder.set_bytes(total_batch_size, 6);
   } else if (plan.rader_n > 1) {
     auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
     copies.push_back(b_q);
@@ -723,22 +723,22 @@ void fft_op(
     compute_encoder.set_input_array(b_q, 2);
     compute_encoder.set_input_array(g_q, 3);
     compute_encoder.set_input_array(g_minus_q, 4);
-    compute_encoder->setBytes(&n, sizeof(int), 5);
-    compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
-    compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
+    compute_encoder.set_bytes(n, 5);
+    compute_encoder.set_bytes(total_batch_size, 6);
+    compute_encoder.set_bytes(plan.rader_n, 7);
   } else if (four_step_params.required) {
-    compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
-    compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
-    compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
+    compute_encoder.set_bytes(four_step_params.n1, 2);
+    compute_encoder.set_bytes(four_step_params.n2, 3);
+    compute_encoder.set_bytes(total_batch_size, 4);
   } else {
-    compute_encoder->setBytes(&n, sizeof(int), 2);
-    compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
+    compute_encoder.set_bytes(n, 2);
+    compute_encoder.set_bytes(total_batch_size, 3);
   }
 
   auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
   auto grid_dims =
       MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
 d.add_temporaries(std::move(copies), s.index);

View File

@@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kernel_name, lib);
     assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(out, 1);
-    compute_encoder->setBytes(&scale, sizeof(float), 2);
+    compute_encoder.set_bytes(scale, 2);
 
     MTL::Size group_dims = MTL::Size(1, threads_per, 1);
     MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   };
 
   if (m > 1) {

View File

@@ -87,7 +87,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   size_t slice_size = 1;
   for (auto s : slice_sizes_) {
@@ -131,20 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 1);
 
   // Set source info
-  set_vector_bytes(compute_encoder, src.shape(), 2);
-  set_vector_bytes(compute_encoder, src.strides(), 3);
-  compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
-  set_vector_bytes(compute_encoder, slice_sizes_, 5);
-  set_vector_bytes(compute_encoder, axes_, 6);
+  compute_encoder.set_vector_bytes(src.shape(), 2);
+  compute_encoder.set_vector_bytes(src.strides(), 3);
+  compute_encoder.set_bytes(ndim, 4);
+  compute_encoder.set_vector_bytes(slice_sizes_, 5);
+  compute_encoder.set_vector_bytes(axes_, 6);
 
   // Set index info
   //
   // We don't need to check for empty idx_shapes because gather has a
   // idx_ndim == 0 specialization
-  set_vector_bytes(compute_encoder, idx_shapes, 7);
-  set_vector_bytes(compute_encoder, idx_strides, 8);
-  set_vector_bytes(compute_encoder, idx_contigs, 9);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
+  compute_encoder.set_vector_bytes(idx_shapes, 7);
+  compute_encoder.set_vector_bytes(idx_strides, 8);
+  compute_encoder.set_vector_bytes(idx_contigs, 9);
+  compute_encoder.set_bytes(idx_ndim, 10);
 
   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -152,7 +152,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Launch grid
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -289,7 +289,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   size_t nthreads = upd.size();
 
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Set all the buffers
   compute_encoder.set_input_array(upd, 1);
@@ -323,14 +323,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Need placeholders so Metal doesn't compalain
     int shape_ = 0;
     size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 3);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
+    compute_encoder.set_bytes(shape_, 3);
+    compute_encoder.set_bytes(stride_, 4);
   } else {
-    set_vector_bytes(compute_encoder, upd.shape(), 3);
-    set_vector_bytes(compute_encoder, upd.strides(), 4);
+    compute_encoder.set_vector_bytes(upd.shape(), 3);
+    compute_encoder.set_vector_bytes(upd.strides(), 4);
   }
-  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
-  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
+  compute_encoder.set_bytes(upd_ndim, 5);
+  compute_encoder.set_bytes(upd_size, 6);
 
   // Set output info
   size_t out_ndim = out.ndim();
@@ -338,14 +338,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Need placeholders so Metal doesn't compalain
     int shape_ = 0;
     size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 7);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
+    compute_encoder.set_bytes(shape_, 7);
+    compute_encoder.set_bytes(stride_, 8);
   } else {
-    set_vector_bytes(compute_encoder, out.shape(), 7);
-    set_vector_bytes(compute_encoder, out.strides(), 8);
+    compute_encoder.set_vector_bytes(out.shape(), 7);
+    compute_encoder.set_vector_bytes(out.strides(), 8);
   }
-  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
-  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
+  compute_encoder.set_bytes(out_ndim, 9);
+  compute_encoder.set_vector_bytes(axes_, 10);
 
   // Set index info
   if (idx_ndim == 0) {
@@ -355,11 +355,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     idx_strides.push_back(0);
     idx_contigs.push_back(false);
   }
-  set_vector_bytes(compute_encoder, idx_shapes, 11);
-  set_vector_bytes(compute_encoder, idx_strides, 12);
-  set_vector_bytes(compute_encoder, idx_contigs, 13);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
-  compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
+  compute_encoder.set_vector_bytes(idx_shapes, 11);
+  compute_encoder.set_vector_bytes(idx_strides, 12);
+  compute_encoder.set_vector_bytes(idx_contigs, 13);
+  compute_encoder.set_bytes(idx_ndim, 14);
+  compute_encoder.set_bytes(idx_size, 15);
 
   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -375,7 +375,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
   }
   MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core

View File

@@ -17,12 +17,15 @@ template <typename T, int N_READS = RMS_N_READS>
     constant float& eps,
     constant uint& axis_size,
     constant uint& w_stride,
-    threadgroup float* local_inv_mean [[threadgroup(0)]],
-    threadgroup float* local_sums [[threadgroup(1)]],
     uint gid [[threadgroup_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup float local_inv_mean[1];
+  threadgroup float local_sums[SIMD_SIZE];
+
   float acc = 0;
   x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
@@ -84,13 +87,15 @@ template <typename T, int N_READS = RMS_N_READS>
     constant float& eps,
     constant uint& axis_size,
     constant uint& w_stride,
-    threadgroup float* local_inv_mean [[threadgroup(0)]],
-    threadgroup float* local_sums [[threadgroup(1)]],
     uint gid [[threadgroup_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup float local_inv_mean[1];
+  threadgroup float local_sums[SIMD_SIZE];
+
   float acc = 0;
   x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
@@ -376,8 +381,6 @@ template <typename T, int N_READS = RMS_N_READS>
       constant float& eps,                                  \
      constant uint& axis_size,                             \
       constant uint& w_stride,                              \
-      threadgroup float* local_inv_mean [[threadgroup(0)]], \
-      threadgroup float* local_sums [[threadgroup(1)]],     \
       uint gid [[thread_position_in_grid]],                 \
       uint lid [[thread_position_in_threadgroup]],          \
       uint simd_lane_id [[thread_index_in_simdgroup]],      \
@@ -407,8 +410,6 @@ template <typename T, int N_READS = RMS_N_READS>
      constant float& eps,                                  \
       constant uint& axis_size,                             \
       constant uint& w_stride,                              \
-      threadgroup float* local_inv_mean [[threadgroup(0)]], \
-      threadgroup float* local_sums [[threadgroup(1)]],     \
       uint gid [[thread_position_in_grid]],                 \
       uint lid [[thread_position_in_threadgroup]],          \
       uint lsize [[threads_per_threadgroup]],               \
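The kernel-side change in these RMS-norm hunks: the `local_inv_mean` and `local_sums` scratch buffers become fixed-size `threadgroup` arrays declared in the kernel body, so hosts no longer bind them as `[[threadgroup(n)]]` arguments or size them from the host side. A self-contained sketch of the pattern under those assumptions (the kernel name and reduction body are made up for illustration):

  #include <metal_stdlib>
  using namespace metal;

  // Illustrative only: threadgroup scratch declared statically in-kernel,
  // mirroring the hunks above.
  [[kernel]] void example_sum(
      const device float* x [[buffer(0)]],
      device float* out [[buffer(1)]],
      uint lid [[thread_position_in_threadgroup]],
      uint lsize [[threads_per_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
    constexpr int SIMD_SIZE = 32;
    threadgroup float local_sums[SIMD_SIZE]; // fixed-size scratch, no host binding

    float acc = simd_sum(x[lid]);            // per-simdgroup partial sum
    if (simd_lane_id == 0) {
      local_sums[simd_group_id] = acc;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lid == 0) {
      float total = 0;
      for (uint i = 0; i < (lsize + SIMD_SIZE - 1) / SIMD_SIZE; i++) {
        total += local_sums[i];
      }
      *out = total;
    }
  }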

View File

@ -249,7 +249,7 @@ void steel_matmul_regular(
wm, wm,
wn); wn);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle // Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
@ -288,12 +288,12 @@ void steel_matmul_regular(
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4); compute_encoder.set_bytes(params, 4);
set_vector_bytes(compute_encoder, batch_shape, 6); compute_encoder.set_vector_bytes(batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7); compute_encoder.set_vector_bytes(batch_strides, 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Record copies // Record copies
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@ -390,7 +390,7 @@ void steel_matmul(
wn, wn,
mn_aligned, mn_aligned,
k_aligned); k_aligned);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm; int tm = (M + bm - 1) / bm;
@ -416,34 +416,30 @@ void steel_matmul(
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2); compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3); compute_encoder.set_bytes(params, 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Do accum kernel // Do accum kernel
{ {
auto c_split_buf =
static_cast<const MTL::Resource*>(C_split.buffer().ptr());
const class MTL::Resource* const resources[1] = {c_split_buf};
compute_encoder->memoryBarrier(resources, 1);
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split); type_to_name(C_split);
auto kernel = get_steel_gemm_splitk_accum_kernel( auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, false); d, kernel_name, C_split, out, false);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Set the arguments for the kernel // Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0); compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2); compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3); compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder->setBytes(&N, sizeof(int), 4); compute_encoder.set_bytes(N, 4);
// Launch enough thread groups for each output // Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1); MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1); MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@ -625,7 +621,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size group_dims = MTL::Size(32, bn, bm);
@ -635,16 +631,16 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1); compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6); compute_encoder.set_bytes(mat_ld, 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); compute_encoder.set_bytes(batch_ndim, 9);
set_vector_bytes(compute_encoder, batch_shape, 10); compute_encoder.set_vector_bytes(batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11); compute_encoder.set_vector_bytes(batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12); compute_encoder.set_vector_bytes(batch_strides_mat, 12);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
@ -822,7 +818,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size group_dims = MTL::Size(32, bn, bm);
@ -833,23 +829,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2); compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6); compute_encoder.set_bytes(mat_ld, 6);
compute_encoder->setBytes(&alpha_, sizeof(float), 7); compute_encoder.set_bytes(alpha_, 7);
compute_encoder->setBytes(&beta_, sizeof(float), 8); compute_encoder.set_bytes(beta_, 8);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); compute_encoder.set_bytes(batch_ndim, 9);
set_vector_bytes(compute_encoder, batch_shape, 10); compute_encoder.set_vector_bytes(batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11); compute_encoder.set_vector_bytes(batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12); compute_encoder.set_vector_bytes(batch_strides_mat, 12);
set_vector_bytes(compute_encoder, C_batch_stride, 13); compute_encoder.set_vector_bytes(C_batch_stride, 13);
int bias_stride = c.strides()[c.ndim() - 1]; int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14); compute_encoder.set_bytes(bias_stride, 14);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
@ -907,7 +903,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
mn_aligned, mn_aligned,
k_aligned); k_aligned);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm; int tm = (M + bm - 1) / bm;
@ -933,8 +929,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2); compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3); compute_encoder.set_bytes(params, 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Do accum kernel // Do accum kernel
{ {
@ -943,25 +939,25 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = get_steel_gemm_splitk_accum_kernel( auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, true); d, kernel_name, C_split, out, true);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Set the arguments for the kernel // Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0); compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2); compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3); compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder->setBytes(&N, sizeof(int), 4); compute_encoder.set_bytes(N, 4);
compute_encoder.set_input_array(c, 5); compute_encoder.set_input_array(c, 5);
compute_encoder->setBytes(&ldc, sizeof(int), 6); compute_encoder.set_bytes(ldc, 6);
compute_encoder->setBytes(&fdc, sizeof(int), 7); compute_encoder.set_bytes(fdc, 7);
compute_encoder->setBytes(&alpha_, sizeof(float), 8); compute_encoder.set_bytes(alpha_, 8);
compute_encoder->setBytes(&beta_, sizeof(float), 9); compute_encoder.set_bytes(beta_, 9);
// Launch enough thread groups for each output // Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1); MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1); MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@ -1032,7 +1028,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm, wm,
wn); wn);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm; int tm = (M + bm - 1) / bm;
@ -1083,13 +1079,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2); compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4); compute_encoder.set_bytes(gemm_params, 4);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5); compute_encoder.set_bytes(params, 5);
set_vector_bytes(compute_encoder, batch_shape, 6); compute_encoder.set_vector_bytes(batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7); compute_encoder.set_vector_bytes(batch_strides, 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
@ -1304,7 +1300,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
contiguous_kernel); contiguous_kernel);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size group_dims = MTL::Size(32, bn, bm);
@ -1372,18 +1368,18 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1); compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6); compute_encoder.set_bytes(mat_ld, 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); compute_encoder.set_bytes(batch_ndim, 9);
set_vector_bytes(compute_encoder, batch_shape, 10); compute_encoder.set_vector_bytes(batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11); compute_encoder.set_vector_bytes(batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12); compute_encoder.set_vector_bytes(batch_strides_mat, 12);
set_vector_bytes(compute_encoder, mask_strides, 23); compute_encoder.set_vector_bytes(mask_strides, 23);
set_vector_bytes(compute_encoder, mask_batch_strides, 24); compute_encoder.set_vector_bytes(mask_batch_strides, 24);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
@ -1423,7 +1419,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wn, wn,
mn_aligned, mn_aligned,
k_aligned); k_aligned);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle // Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
@@ -1486,14 +1482,14 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4); compute_encoder.set_bytes(params, 4);
set_vector_bytes(compute_encoder, batch_shape, 6); compute_encoder.set_vector_bytes(batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7); compute_encoder.set_vector_bytes(batch_strides, 7);
set_vector_bytes(compute_encoder, mask_strides, 13); compute_encoder.set_vector_bytes(mask_strides, 13);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
@@ -1687,7 +1683,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1697,28 +1693,28 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1); compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6); compute_encoder.set_bytes(mat_ld, 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); compute_encoder.set_bytes(batch_ndim, 9);
set_vector_bytes(compute_encoder, batch_shape, 10); compute_encoder.set_vector_bytes(batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides, 11); compute_encoder.set_vector_bytes(batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size(); int batch_ndim_vec = batch_shape_vec.size();
compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12); compute_encoder.set_bytes(batch_ndim_vec, 12);
set_vector_bytes(compute_encoder, batch_shape_vec, 13); compute_encoder.set_vector_bytes(batch_shape_vec, 13);
set_vector_bytes(compute_encoder, batch_strides_vec, 14); compute_encoder.set_vector_bytes(batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size(); int batch_ndim_mat = batch_shape_mat.size();
compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15); compute_encoder.set_bytes(batch_ndim_mat, 15);
set_vector_bytes(compute_encoder, batch_shape_mat, 16); compute_encoder.set_vector_bytes(batch_shape_mat, 16);
set_vector_bytes(compute_encoder, batch_strides_mat, 17); compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix)); compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix)); compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
return; return;
@@ -1788,7 +1784,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm, wm,
wn); wn);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle // Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
@@ -1827,10 +1823,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1); compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4); compute_encoder.set_bytes(params, 4);
set_vector_bytes(compute_encoder, batch_shape, 6); compute_encoder.set_vector_bytes(batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7); compute_encoder.set_vector_bytes(batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10); compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11); compute_encoder.set_input_array(rhs_indices, 11);
@@ -1845,11 +1841,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
operand_batch_ndim.push_back(0); operand_batch_ndim.push_back(0);
set_vector_bytes(compute_encoder, operand_shape, 13); compute_encoder.set_vector_bytes(operand_shape, 13);
set_vector_bytes(compute_encoder, operand_strides, 14); compute_encoder.set_vector_bytes(operand_strides, 14);
set_vector_bytes(compute_encoder, operand_batch_ndim, 15); compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
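
Every hunk above applies the same mechanical rewrite: `compute_encoder->setBytes(&value, sizeof(T), index)` becomes `compute_encoder.set_bytes(value, index)`, and the free-standing `set_vector_bytes(compute_encoder, vec, index)` helpers become member calls, so byte counts are deduced from argument types instead of being spelled out at every call site. A minimal standalone sketch of that technique, with a printing mock in place of the wrapped MTL encoder (only the member names mirror the diff; everything else is illustrative):

    #include <cstdio>
    #include <vector>

    namespace sketch {

    struct CommandEncoder {
      // In the real wrapper these forward to
      // MTL::ComputeCommandEncoder::setBytes; here we just print the
      // byte count that the template deduces.
      template <typename T>
      void set_bytes(const T& v, int idx) {
        std::printf("buffer %d: %zu bytes\n", idx, sizeof(T));
      }

      template <typename T>
      void set_vector_bytes(const std::vector<T>& vec, int idx) {
        std::printf("buffer %d: %zu bytes\n", idx, vec.size() * sizeof(T));
      }
    };

    } // namespace sketch

    int main() {
      sketch::CommandEncoder enc;
      float alpha = 2.0f;
      std::vector<int> batch_shape{4, 8, 16};
      enc.set_bytes(alpha, 3);              // 4 bytes, sizeof deduced
      enc.set_vector_bytes(batch_shape, 6); // 12 bytes, count from the vector
      return 0;
    }

The payoff shows throughout these hunks: a stale `sizeof` or a stray `&` can no longer desynchronize the host-side argument from what the kernel expects.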

View File

@@ -78,18 +78,15 @@ void RMSNorm::eval_gpu(
} }
uint32_t w_stride = w.strides()[0]; uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0); x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(w, 1);
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&eps_, sizeof(float), 3); compute_encoder.set_bytes(eps_, 3);
compute_encoder->setBytes(&axis_size, sizeof(int), 4); compute_encoder.set_bytes(axis_size, 4);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5); compute_encoder.set_bytes(w_stride, 5);
compute_encoder->setThreadgroupMemoryLength( compute_encoder.dispatch_threads(grid_dims, group_dims);
16 * 8, 0); // minimum of 16 bytes
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@@ -183,16 +180,16 @@ void RMSNormVJP::eval_gpu(
} }
uint32_t w_stride = w.strides()[0]; uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0); compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2); compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4); compute_encoder.set_output_array(gw_temp, 4);
compute_encoder->setBytes(&eps_, sizeof(float), 5); compute_encoder.set_bytes(eps_, 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6); compute_encoder.set_bytes(axis_size, 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7); compute_encoder.set_bytes(w_stride, 7);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
ReductionPlan plan( ReductionPlan plan(
@@ -273,17 +270,17 @@ void LayerNorm::eval_gpu(
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
x.data_shared_ptr() == nullptr ? out : x, 0); x.data_shared_ptr() == nullptr ? out : x, 0);
compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(b, 2); compute_encoder.set_input_array(b, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&eps_, sizeof(float), 4); compute_encoder.set_bytes(eps_, 4);
compute_encoder->setBytes(&axis_size, sizeof(int), 5); compute_encoder.set_bytes(axis_size, 5);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6); compute_encoder.set_bytes(w_stride, 6);
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7); compute_encoder.set_bytes(b_stride, 7);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@@ -395,16 +392,16 @@ void LayerNormVJP::eval_gpu(
} }
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0); compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2); compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4); compute_encoder.set_output_array(gw_temp, 4);
compute_encoder->setBytes(&eps_, sizeof(float), 5); compute_encoder.set_bytes(eps_, 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6); compute_encoder.set_bytes(axis_size, 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7); compute_encoder.set_bytes(w_stride, 7);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
if (gw.ndim() == 1 && gw.size() == axis_size) { if (gw.ndim() == 1 && gw.size() == axis_size) {

View File

@@ -17,10 +17,10 @@
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>
void arange_set_scalars(T start, T next, CommandEncoder& enc) { void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc->setBytes(&start, sizeof(T), 0); enc.set_bytes(start, 0);
T step = next - start; T step = next - start;
enc->setBytes(&step, sizeof(T), 1); enc.set_bytes(step, 1);
} }
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) { void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -37,7 +37,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size( MTL::Size group_dims = MTL::Size(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
switch (out.dtype()) { switch (out.dtype()) {
case bool_: // unsupported case bool_: // unsupported
@@ -80,7 +80,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) { void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -129,25 +129,25 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t n_threads = out.size() * thread_group_size; size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
if (ndim == 0) { if (ndim == 0) {
// Pass placeholders so metal doesn't complain // Pass placeholders so metal doesn't complain
int shape_ = 0; int shape_ = 0;
size_t stride_ = 0; size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 2); compute_encoder.set_bytes(shape_, 2);
compute_encoder->setBytes(&stride_, sizeof(size_t), 3); compute_encoder.set_bytes(stride_, 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4); compute_encoder.set_bytes(stride_, 4);
} else { } else {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2); compute_encoder.set_vector_bytes(shape, 2);
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3); compute_encoder.set_vector_bytes(in_strides, 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4); compute_encoder.set_vector_bytes(out_strides, 4);
} }
compute_encoder->setBytes(&ndim, sizeof(size_t), 5); compute_encoder.set_bytes(ndim, 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6); compute_encoder.set_bytes(axis_stride, 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7); compute_encoder.set_bytes(axis_size, 7);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }
@@ -275,22 +275,20 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1); MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(keys, 0); compute_encoder.set_input_array(keys, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&odd, sizeof(bool), 2); compute_encoder.set_bytes(odd, 2);
compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3); compute_encoder.set_bytes(bytes_per_key, 3);
if (!keys.flags().row_contiguous) { if (!keys.flags().row_contiguous) {
int ndim = keys.ndim(); int ndim = keys.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4); compute_encoder.set_bytes(ndim, 4);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(keys.shape(), 5);
keys.shape().data(), keys.ndim() * sizeof(int), 5); compute_encoder.set_vector_bytes(keys.strides(), 6);
compute_encoder->setBytes(
keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
} }
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) { void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -101,31 +101,31 @@ void launch_qmm(
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3); compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4); compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5); compute_encoder.set_bytes(D, 5);
compute_encoder->setBytes(&O, sizeof(int), 6); compute_encoder.set_bytes(O, 6);
int offset = 7; int offset = 7;
if (matrix) { if (matrix) {
compute_encoder->setBytes(&B, sizeof(int), 7); compute_encoder.set_bytes(B, 7);
offset += 1; offset += 1;
} }
if (batched || gather) { if (batched || gather) {
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset); compute_encoder.set_bytes(x_batch_ndims, offset);
set_vector_bytes(compute_encoder, x_shape, offset + 1); compute_encoder.set_vector_bytes(x_shape, offset + 1);
set_vector_bytes(compute_encoder, x_strides, offset + 2); compute_encoder.set_vector_bytes(x_strides, offset + 2);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3); compute_encoder.set_bytes(w_batch_ndims, offset + 3);
set_vector_bytes(compute_encoder, w_shape, offset + 4); compute_encoder.set_vector_bytes(w_shape, offset + 4);
set_vector_bytes(compute_encoder, w_strides, offset + 5); compute_encoder.set_vector_bytes(w_strides, offset + 5);
set_vector_bytes(compute_encoder, s_strides, offset + 6); compute_encoder.set_vector_bytes(s_strides, offset + 6);
set_vector_bytes(compute_encoder, b_strides, offset + 7); compute_encoder.set_vector_bytes(b_strides, offset + 7);
} }
if (gather) { if (gather) {
auto& lhs_indices = inputs[4]; auto& lhs_indices = inputs[4];
@@ -137,15 +137,15 @@ void launch_qmm(
auto& lhs_strides = lhs_indices.strides(); auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides(); auto& rhs_strides = rhs_indices.strides();
compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8); compute_encoder.set_bytes(batch_ndims, offset + 8);
set_vector_bytes(compute_encoder, batch_shape, offset + 9); compute_encoder.set_vector_bytes(batch_shape, offset + 9);
compute_encoder.set_input_array(lhs_indices, offset + 10); compute_encoder.set_input_array(lhs_indices, offset + 10);
compute_encoder.set_input_array(rhs_indices, offset + 11); compute_encoder.set_input_array(rhs_indices, offset + 11);
set_vector_bytes(compute_encoder, lhs_strides, offset + 12); compute_encoder.set_vector_bytes(lhs_strides, offset + 12);
set_vector_bytes(compute_encoder, rhs_strides, offset + 13); compute_encoder.set_vector_bytes(rhs_strides, offset + 13);
} }
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
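
One detail worth noting in `launch_qmm`: optional kernel arguments are laid out with a running buffer index. Slots 0 through 6 are always bound (`w`, `scales`, `biases`, `x`, `out`, `D`, `O`), slot 7 holds `B` only for matrix kernels, and the batched/gather metadata is addressed relative to `offset`. A condensed standalone sketch of that layout logic (the boolean flags are placeholders for the real dispatch parameters):

    #include <cstdio>

    // Mirrors the argument layout in launch_qmm: fixed slots first, then
    // optional groups at a running offset so each kernel variant sees its
    // arguments at predictable indices.
    void show_layout(bool matrix, bool batched, bool gather) {
      int offset = 7; // slots 0-6 are always bound
      if (matrix) {
        std::printf("B             -> buffer %d\n", offset);
        offset += 1;
      }
      if (batched || gather) {
        std::printf("x_batch_ndims -> buffer %d\n", offset);
        std::printf("b_strides     -> buffer %d\n", offset + 7);
      }
      if (gather) {
        std::printf("batch_ndims   -> buffer %d\n", offset + 8);
        std::printf("rhs_strides   -> buffer %d\n", offset + 13);
      }
    }

    int main() {
      show_layout(/*matrix=*/true, /*batched=*/true, /*gather=*/false);
      return 0;
    }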
@@ -236,27 +236,27 @@ void qvm_split_k(
// Encode and dispatch kernel // Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3); compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4); compute_encoder.set_output_array(intermediate, 4);
compute_encoder->setBytes(&split_D, sizeof(int), 5); compute_encoder.set_bytes(split_D, 5);
compute_encoder->setBytes(&O, sizeof(int), 6); compute_encoder.set_bytes(O, 6);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7); compute_encoder.set_bytes(x_batch_ndims, 7);
set_vector_bytes(compute_encoder, x_shape, 8); compute_encoder.set_vector_bytes(x_shape, 8);
set_vector_bytes(compute_encoder, x_strides, 9); compute_encoder.set_vector_bytes(x_strides, 9);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10); compute_encoder.set_bytes(w_batch_ndims, 10);
set_vector_bytes(compute_encoder, w_shape, 11); compute_encoder.set_vector_bytes(w_shape, 11);
set_vector_bytes(compute_encoder, w_strides, 12); compute_encoder.set_vector_bytes(w_strides, 12);
set_vector_bytes(compute_encoder, s_strides, 13); compute_encoder.set_vector_bytes(s_strides, 13);
set_vector_bytes(compute_encoder, b_strides, 14); compute_encoder.set_vector_bytes(b_strides, 14);
compute_encoder->setBytes(&final_block_size, sizeof(int), 15); compute_encoder.set_bytes(final_block_size, 15);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
int axis = intermediate.ndim() - 3; int axis = intermediate.ndim() - 3;
@@ -447,7 +447,7 @@ void fast::AffineQuantize::eval_gpu(
auto template_def = get_template_definition( auto template_def = get_template_definition(
kname.str(), kernel_func, type_string, group_size_, bits_); kname.str(), kernel_func, type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Treat uint32 as uint8 in kernel // Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4; constexpr int uint8_per_uint32 = 4;
@@ -471,7 +471,7 @@ void fast::AffineQuantize::eval_gpu(
} }
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides()) MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }

View File

@@ -67,17 +67,14 @@ struct RowReduceArgs {
strides.push_back(0); strides.push_back(0);
} }
compute_encoder->setBytes(&row_size, sizeof(size_t), 2); compute_encoder.set_bytes(row_size, 2);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 3); compute_encoder.set_bytes(non_row_reductions, 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); compute_encoder.set_vector_bytes(shape, 4);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(strides, 5);
strides.data(), strides.size() * sizeof(size_t), 5); compute_encoder.set_bytes(ndim, 6);
compute_encoder->setBytes(&ndim, sizeof(int), 6); compute_encoder.set_vector_bytes(reduce_shape, 7);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(reduce_strides, 8);
reduce_shape.data(), reduce_shape.size() * sizeof(int), 7); compute_encoder.set_bytes(reduce_ndim, 9);
compute_encoder->setBytes(
reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9);
if (reduce_ndim == 0) { if (reduce_ndim == 0) {
reduce_shape.pop_back(); reduce_shape.pop_back();
@@ -166,18 +163,15 @@ struct ColReduceArgs {
strides.push_back(0); strides.push_back(0);
} }
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); compute_encoder.set_bytes(reduction_size, 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); compute_encoder.set_bytes(reduction_stride, 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); compute_encoder.set_vector_bytes(shape, 4);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(strides, 5);
strides.data(), strides.size() * sizeof(size_t), 5); compute_encoder.set_bytes(ndim, 6);
compute_encoder->setBytes(&ndim, sizeof(int), 6); compute_encoder.set_vector_bytes(reduce_shape, 7);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(reduce_strides, 8);
reduce_shape.data(), reduce_shape.size() * sizeof(int), 7); compute_encoder.set_bytes(reduce_ndim, 9);
compute_encoder->setBytes( compute_encoder.set_bytes(non_col_reductions, 10);
reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9);
compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 10);
if (reduce_ndim == 0) { if (reduce_ndim == 0) {
reduce_shape.pop_back(); reduce_shape.pop_back();
@@ -256,9 +250,9 @@ void init_reduce(
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_output_array(out, 0); compute_encoder.set_output_array(out, 0);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void all_reduce_dispatch( void all_reduce_dispatch(
@@ -273,7 +267,7 @@ void all_reduce_dispatch(
const std::string func_name = "all_reduce"; const std::string func_name = "all_reduce";
kname << func_name << "_" << op_name << type_to_name(in); kname << func_name << "_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
size_t in_size = in.size(); size_t in_size = in.size();
@@ -285,9 +279,9 @@ void all_reduce_dispatch(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2); compute_encoder.set_bytes(in_size, 2);
compute_encoder->setBytes(&in_size, sizeof(size_t), 3); compute_encoder.set_bytes(in_size, 3);
compute_encoder.dispatchThreads(grid_dims, grid_dims); compute_encoder.dispatch_threads(grid_dims, grid_dims);
} }
// We need multiple threadgroups so we'll do it in 2 passes. // We need multiple threadgroups so we'll do it in 2 passes.
@@ -319,24 +313,24 @@ void all_reduce_dispatch(
MTL::Size group_dims(threadgroup_size, 1, 1); MTL::Size group_dims(threadgroup_size, 1, 1);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1); compute_encoder.set_output_array(intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2); compute_encoder.set_bytes(in_size, 2);
compute_encoder->setBytes(&row_size, sizeof(size_t), 3); compute_encoder.set_bytes(row_size, 3);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
// 2nd pass // 2nd pass
std::ostringstream kname_2nd_pass; std::ostringstream kname_2nd_pass;
kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate); kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass = get_reduce_kernel( auto kernel_2nd_pass = get_reduce_kernel(
d, kname_2nd_pass.str(), func_name, op_name, intermediate, out); d, kname_2nd_pass.str(), func_name, op_name, intermediate, out);
compute_encoder->setComputePipelineState(kernel_2nd_pass); compute_encoder.set_compute_pipeline_state(kernel_2nd_pass);
size_t intermediate_size = n_rows; size_t intermediate_size = n_rows;
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2); compute_encoder.set_bytes(intermediate_size, 2);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 3); compute_encoder.set_bytes(intermediate_size, 3);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }
@@ -355,7 +349,7 @@ void row_reduce_small(
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Figure out the grid dims // Figure out the grid dims
MTL::Size grid_dims; MTL::Size grid_dims;
@@ -375,7 +369,7 @@ void row_reduce_small(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void row_reduce_simple( void row_reduce_simple(
@@ -391,7 +385,7 @@ void row_reduce_simple(
const std::string func_name = "row_reduce_simple"; const std::string func_name = "row_reduce_simple";
kname << func_name << "_" << op_name << type_to_name(in); kname << func_name << "_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Figure out the grid dims // Figure out the grid dims
size_t row_size = args.row_size; size_t row_size = args.row_size;
@@ -410,9 +404,9 @@ void row_reduce_simple(
// Launch // Launch
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&row_size, sizeof(size_t), 2); compute_encoder.set_bytes(row_size, 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3); compute_encoder.set_bytes(out_size, 3);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void row_reduce_looped( void row_reduce_looped(
@@ -430,7 +424,7 @@ void row_reduce_looped(
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Figure out the grid // Figure out the grid
auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides()); auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
@@ -443,7 +437,7 @@ void row_reduce_looped(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void row_reduce_general_dispatch( void row_reduce_general_dispatch(
@@ -495,7 +489,7 @@ void strided_reduce_small(
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
const int n_reads = 4; const int n_reads = 4;
size_t reduction_stride_blocks = size_t reduction_stride_blocks =
@@ -517,7 +511,7 @@ void strided_reduce_small(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
void strided_reduce_longcolumn( void strided_reduce_longcolumn(
@@ -568,14 +562,14 @@ void strided_reduce_longcolumn(
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Launch // Launch
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1); compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11); compute_encoder.set_bytes(out_size, 11);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims // Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate); ColReduceArgs second_args(intermediate);
@@ -599,12 +593,12 @@ void strided_reduce_longcolumn(
1, 1,
32, 32,
32); 32);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder); second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void strided_reduce_looped( void strided_reduce_looped(
@@ -639,13 +633,13 @@ void strided_reduce_looped(
<< op_name << type_to_name(in); << op_name << type_to_name(in);
auto kernel = auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN); get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Launch // Launch
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void strided_reduce_2pass( void strided_reduce_2pass(
@@ -692,14 +686,14 @@ void strided_reduce_2pass(
<< op_name << type_to_name(in); << op_name << type_to_name(in);
auto kernel = auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN); get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Launch // Launch
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1); compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder); args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11); compute_encoder.set_bytes(out_size, 11);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims // Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate); ColReduceArgs second_args(intermediate);
@@ -721,12 +715,12 @@ void strided_reduce_2pass(
1, 1,
32, 32,
32); 32);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder); second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void strided_reduce_general_dispatch( void strided_reduce_general_dispatch(

View File

@@ -75,24 +75,24 @@ void RoPE::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
float base = std::log2(base_); float base = std::log2(base_);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0); compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&offset_, sizeof(int), 2); compute_encoder.set_bytes(offset_, 2);
compute_encoder->setBytes(&scale_, sizeof(float), 3); compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size; size_t n_batch = in.size() / mat_size;
MTL::Size group_dims; MTL::Size group_dims;
MTL::Size grid_dims; MTL::Size grid_dims;
if (single) { if (single) {
compute_encoder->setBytes(out_strides, sizeof(size_t), 4); compute_encoder.set_bytes(out_strides, 1, 4);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1); group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1); grid_dims = MTL::Size(dim0, n_batch, 1);
} else { } else {
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4); compute_encoder.set_bytes(strides, 3, 4);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5); compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6); compute_encoder.set_bytes(n_batch, 6);
uint32_t dim0 = dims_ / 2; uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2); uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread; uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
@@ -104,11 +104,11 @@ void RoPE::eval_gpu(
auto& freqs = inputs[1]; auto& freqs = inputs[1];
compute_encoder.set_input_array(freqs, 10); compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0]; auto freq_stride = freqs.strides()[0];
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11); compute_encoder.set_bytes(freq_stride, 11);
} else { } else {
compute_encoder->setBytes(&base, sizeof(float), 10); compute_encoder.set_bytes(base, 10);
} }
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} // namespace mlx::core::fast } // namespace mlx::core::fast
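
The RoPE hunk also exercises a third form: `set_bytes(strides, 3, 4)` passes a raw array with an explicit element count, alongside the single-value calls used for `offset_`, `scale_`, and `base`. A short sketch of how such an overload can coexist with the scalar one (mock encoder; the pointer signature is inferred from these call sites, so treat it as an assumption):

    #include <cstddef>
    #include <cstdio>

    struct Encoder {
      // Single trivially copyable value: size deduced from the type.
      template <typename T>
      void set_bytes(const T& v, int idx) {
        std::printf("buffer %d: %zu bytes\n", idx, sizeof(T));
      }
      // Raw array / pointer plus an explicit element count, as used for the
      // three-element stride arrays above.
      template <typename T>
      void set_bytes(const T* v, int n, int idx) {
        std::printf("buffer %d: %zu bytes\n", idx, n * sizeof(T));
      }
    };

    int main() {
      Encoder enc;
      size_t strides[3] = {64, 8, 1};
      float base = 10000.0f;
      enc.set_bytes(strides, 3, 4); // 24 bytes: pointer + count
      enc.set_bytes(base, 10);      // 4 bytes: scalar
      return 0;
    }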

View File

@@ -59,7 +59,7 @@ void sdpa_full_self_attention_metal(
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_self_attention.str()); auto kernel = d.get_kernel(kname_self_attention.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
uint hidden_dim = q.shape(-1); uint hidden_dim = q.shape(-1);
uint qseq = q.shape(-2); uint qseq = q.shape(-2);
@@ -129,17 +129,14 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_input_array(v, 2); compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(MLXFastAttentionParams), 4); compute_encoder.set_bytes(params, 4);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(batch_shape, 6);
batch_shape.data(), sizeof(int) * batch_shape.size(), 6); compute_encoder.set_vector_bytes(batch_strides, 7);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
MTL::Size group_dims = MTL::Size(32, wm, wn); MTL::Size group_dims = MTL::Size(32, wm, wn);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
void sdpa_vector( void sdpa_vector(
@@ -170,21 +167,21 @@ void sdpa_vector(
// Get the kernel // Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname); auto kernel = d.get_kernel(kname);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments // Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0); compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2); compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3); compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4); compute_encoder.set_bytes(gqa_factor, 4);
compute_encoder->setBytes(&N, sizeof(int), 5); compute_encoder.set_bytes(N, 5);
compute_encoder->setBytes(&k_stride, sizeof(size_t), 6); compute_encoder.set_bytes(k_stride, 6);
compute_encoder->setBytes(&v_stride, sizeof(size_t), 7); compute_encoder.set_bytes(v_stride, 7);
compute_encoder->setBytes(&scale, sizeof(float), 8); compute_encoder.set_bytes(scale, 8);
// Launch // Launch
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
} // namespace } // namespace

View File

@@ -68,12 +68,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
if (contiguous) { if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0); in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_); size_t size = in.shape(axis_);
compute_encoder->setBytes(&size, sizeof(size_t), 2); compute_encoder.set_bytes(size, 2);
// Compute the thread grid // Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2; int n_reads = (in.itemsize() <= 4) ? 4 : 2;
@@ -95,10 +95,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims( MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height); thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1); MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} else { } else {
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0); in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
@@ -107,9 +107,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int bm = 32; int bm = 32;
int bn = 32; int bn = 32;
size_t stride_blocks = (stride + bn - 1) / bn; size_t stride_blocks = (stride + bn - 1) / bn;
compute_encoder->setBytes(&size, sizeof(size_t), 2); compute_encoder.set_bytes(size, 2);
compute_encoder->setBytes(&stride, sizeof(size_t), 3); compute_encoder.set_bytes(stride, 3);
compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4); compute_encoder.set_bytes(stride_blocks, 4);
// Compute the thread grid // Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2; int n_reads = (in.itemsize() <= 4) ? 4 : 2;
@@ -125,7 +125,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims( MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height); thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1); MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);

View File

@@ -81,12 +81,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
group_dims = MTL::Size(threadgroup_size, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1);
} }
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0); in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2); compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);

View File

@@ -68,29 +68,29 @@ void single_block_sort(
// Prepare command encoder // Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Set inputs // Set inputs
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2); compute_encoder.set_bytes(size_sorted_axis, 2);
compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3); compute_encoder.set_bytes(in_stride_sorted_axis, 3);
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4); compute_encoder.set_bytes(out_stride_sorted_axis, 4);
if (contiguous) { if (contiguous) {
compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5); compute_encoder.set_bytes(in_stride_segment_axis, 5);
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6); compute_encoder.set_bytes(out_stride_segment_axis, 6);
} else { } else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 5); compute_encoder.set_bytes(nc_dim, 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6); compute_encoder.set_vector_bytes(nc_shape, 6);
compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7); compute_encoder.set_vector_bytes(in_nc_str, 7);
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8); compute_encoder.set_vector_bytes(out_nc_str, 8);
} }
MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1); MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
void multi_block_sort( void multi_block_sort(
@@ -152,22 +152,21 @@ void multi_block_sort(
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn; << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
auto kernel = auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(dev_vals_0, 1); compute_encoder.set_output_array(dev_vals_0, 1);
compute_encoder.set_output_array(dev_idxs_0, 2); compute_encoder.set_output_array(dev_idxs_0, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3); compute_encoder.set_bytes(size_sorted_axis, 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4); compute_encoder.set_bytes(stride_sorted_axis, 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5); compute_encoder.set_bytes(nc_dim, 5);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(nc_shape, 6);
nc_shape.data(), nc_shape.size() * sizeof(int), 6); compute_encoder.set_vector_bytes(nc_str, 7);
compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
// Do merges // Do merges
@@ -194,19 +193,19 @@ void multi_block_sort(
auto kernel = auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_output_array(block_partitions, 0); compute_encoder.set_output_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1); compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2); compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3); compute_encoder.set_bytes(size_sorted_axis, 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4); compute_encoder.set_bytes(merge_tiles, 4);
compute_encoder->setBytes(&n_blocks, sizeof(int), 5); compute_encoder.set_bytes(n_blocks, 5);
MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1); MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1); MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
// Do merge // Do merge
@@ -217,21 +216,21 @@ void multi_block_sort(
auto kernel = auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(block_partitions, 0); compute_encoder.set_input_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1); compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2); compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder.set_output_array(dev_vals_out, 3); compute_encoder.set_output_array(dev_vals_out, 3);
compute_encoder.set_output_array(dev_idxs_out, 4); compute_encoder.set_output_array(dev_idxs_out, 4);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5); compute_encoder.set_bytes(size_sorted_axis, 5);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6); compute_encoder.set_bytes(merge_tiles, 6);
compute_encoder->setBytes(&n_blocks, sizeof(int), 7); compute_encoder.set_bytes(n_blocks, 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
} }

View File

@@ -63,7 +63,7 @@ void ternary_op_gpu_inplace(
auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op); auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
bool donate_a = a.data_shared_ptr() == nullptr; bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr; bool donate_b = b.data_shared_ptr() == nullptr;
bool donate_c = c.data_shared_ptr() == nullptr; bool donate_c = c.data_shared_ptr() == nullptr;
@@ -80,18 +80,18 @@ void ternary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1); size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) { if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); compute_encoder.set_vector_bytes(shape, 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5); compute_encoder.set_vector_bytes(strides_a, 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6); compute_encoder.set_vector_bytes(strides_b, 6);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7); compute_encoder.set_vector_bytes(strides_c, 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8); compute_encoder.set_bytes(ndim, 8);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else { } else {
// The shape is implicit in the grid for <= 3D // The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); compute_encoder.set_vector_bytes(strides_a, 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); compute_encoder.set_vector_bytes(strides_b, 5);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6); compute_encoder.set_vector_bytes(strides_c, 6);
} }
if (thread_group_size != 1024) { if (thread_group_size != 1024) {
@@ -99,7 +99,7 @@ void ternary_op_gpu_inplace(
} }
MTL::Size group_dims = get_block_dims(dim0, dim1, rest); MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} else { } else {
// Launch a 1D or 2D grid of threads // Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size(); size_t nthreads = out.data_size();
@@ -109,7 +109,7 @@ void ternary_op_gpu_inplace(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }

View File

@@ -49,7 +49,7 @@ void unary_op_gpu_inplace(
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array( compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0); in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
@@ -58,16 +58,16 @@ void unary_op_gpu_inplace(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1); size_t rest = out.size() / (dim0 * dim1);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2); compute_encoder.set_vector_bytes(shape, 2);
compute_encoder->setBytes(strides.data(), ndim * sizeof(size_t), 3); compute_encoder.set_vector_bytes(strides, 3);
compute_encoder->setBytes(&ndim, sizeof(int), 4); compute_encoder.set_bytes(ndim, 4);
if (thread_group_size != 1024) { if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::unary] Must use 1024 sized block"); throw std::runtime_error("[Metal::unary] Must use 1024 sized block");
} }
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto group_dims = get_block_dims(dim0, dim1, rest); auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} else { } else {
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {
thread_group_size = nthreads; thread_group_size = nthreads;
@@ -75,7 +75,7 @@ void unary_op_gpu_inplace(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }

View File

@@ -8,23 +8,6 @@
namespace mlx::core { namespace mlx::core {
using metal::CommandEncoder;
template <typename T>
inline void set_vector_bytes(
CommandEncoder& enc,
const std::vector<T>& vec,
size_t nelems,
int idx) {
enc->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
inline void
set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
return set_vector_bytes(enc, vec, vec.size(), idx);
}
std::string type_to_name(const array& a); std::string type_to_name(const array& a);
// Compute the thread block dimensions which fit the given // Compute the thread block dimensions which fit the given
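
The free `set_vector_bytes` helpers deleted in this final hunk are the other half of the change: the same templates now live on the encoder wrapper itself (declared elsewhere in this commit, not shown in this excerpt), including the two-step overload where the element count defaults to `vec.size()`. A standalone sketch of that delegation pattern, mirroring the removed code as member functions:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct CommandEncoder {
      // Explicit element count, e.g. to bind only a prefix of the vector.
      template <typename T>
      void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
        std::printf("buffer %d: %zu bytes\n", idx, nelems * sizeof(T));
      }
      // Common case: bind the whole vector by delegating to the overload above.
      template <typename T>
      void set_vector_bytes(const std::vector<T>& vec, int idx) {
        set_vector_bytes(vec, vec.size(), idx);
      }
    };

    int main() {
      CommandEncoder enc;
      std::vector<size_t> strides{32, 8, 1};
      enc.set_vector_bytes(strides, 7);    // all 3 elements
      enc.set_vector_bytes(strides, 2, 7); // just the first 2
      return 0;
    }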