mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)

Explicit barriers with concurrent dispatch (#977)

This commit is contained in: parent 8580d997ff, commit 12d4507ee3
@@ -289,7 +289,7 @@ void Compiled::eval_gpu(
     }
   }
   auto kernel = d.get_kernel(kernel_name, lib);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
   // Put the inputs in
@@ -300,7 +300,7 @@ void Compiled::eval_gpu(
       continue;
     }
     auto& x = inputs[i];
-    set_array_buffer(compute_encoder, x, cnt++);
+    compute_encoder.set_input_array(x, cnt++);
     if (!contiguous && !is_scalar(x)) {
       compute_encoder->setBytes(
           strides[stride_idx].data(),
@@ -315,7 +315,7 @@ void Compiled::eval_gpu(
 
   // Put the outputs in
   for (auto& x : outputs) {
-    set_array_buffer(compute_encoder, x, cnt++);
+    compute_encoder.set_output_array(x, cnt++);
   }
 
   // Put the output shape and strides in
@@ -41,12 +41,12 @@ void explicit_gemm_conv_ND_gpu(
   // Prepare unfolding kernel
   std::ostringstream kname;
   kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, in_unfolded, 1);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(in_unfolded, 1);
 
   compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
 
@@ -140,7 +140,7 @@ void slow_conv_2D_gpu(
         << "_tm" << tm << "_tn" << tn;
 
   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -153,9 +153,9 @@ void slow_conv_2D_gpu(
   MTL::Size group_dims = MTL::Size(bm, bn, 1);
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
 
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
 
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -241,7 +241,7 @@ void implicit_gemm_conv_2D_gpu(
         << "_filter_" << (small_filter ? 's' : 'l');
 
   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -254,9 +254,9 @@ void implicit_gemm_conv_2D_gpu(
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
 
   // Encode arrays
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
 
   // Encode params
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -394,7 +394,7 @@ void implicit_gemm_conv_2D_general_gpu(
         << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
 
   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -408,9 +408,9 @@ void implicit_gemm_conv_2D_general_gpu(
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
 
   // Encode arrays
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
 
   // Encode params
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -511,12 +511,12 @@ void winograd_conv_2D_gpu(
    std::ostringstream kname;
    kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
          << bc;
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
-   set_array_buffer(compute_encoder, wt, 0);
-   set_array_buffer(compute_encoder, filt_wg, 1);
+   compute_encoder.set_input_array(wt, 0);
+   compute_encoder.set_output_array(filt_wg, 1);
 
    compute_encoder->setBytes(&C_c, sizeof(int), 2);
    compute_encoder->setBytes(&O_c, sizeof(int), 3);
@@ -539,12 +539,12 @@ void winograd_conv_2D_gpu(
    std::ostringstream kname;
    kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
          << bc;
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
-   set_array_buffer(compute_encoder, in_padded, 0);
-   set_array_buffer(compute_encoder, inp_wg, 1);
+   compute_encoder.set_input_array(in_padded, 0);
+   compute_encoder.set_output_array(inp_wg, 1);
 
    compute_encoder->setBytes(
        &conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -587,12 +587,12 @@ void winograd_conv_2D_gpu(
    std::ostringstream kname;
    kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
          << bc;
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
-   set_array_buffer(compute_encoder, out_wg, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(out_wg, 0);
+   compute_encoder.set_output_array(out, 1);
 
    compute_encoder->setBytes(
        &conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -83,15 +83,15 @@ void copy_gpu_inplace(
     kname << "_" << shape.size();
   }
   auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
   bool donate_in = in.data_shared_ptr() == nullptr;
 
   inp_offset *= size_of(in.dtype());
   out_offset *= size_of(out.dtype());
 
-  set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
-  set_array_buffer(compute_encoder, out, out_offset, 1);
+  compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
+  compute_encoder.set_output_array(out, 1, out_offset);
 
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
     int ndim = shape.size();
@@ -1,4 +1,4 @@
-// Copyright © 2023-24 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <dlfcn.h>
 #include <cstdlib>
@@ -206,14 +206,15 @@ void Device::end_encoding(int index) {
   }
 }
 
-MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
+CommandEncoder& Device::get_command_encoder(int index) {
   auto eit = encoder_map_.find(index);
   if (eit == encoder_map_.end()) {
     auto cb = get_command_buffer(index);
-    auto compute_encoder = cb->computeCommandEncoder();
+    auto compute_encoder =
+        cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
    // Increment ref count so the buffer is not garbage collected
    compute_encoder->retain();
-    eit = encoder_map_.insert({index, compute_encoder}).first;
+    eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first;
   }
   return eit->second;
 }
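Note on the hunk above: creating the encoder with MTL::DispatchTypeConcurrent lets Metal overlap dispatches within a single encoder instead of serializing them, so ordering between dependent kernels must now be expressed explicitly. A minimal metal-cpp sketch of the raw mechanism this commit builds on (the buffer and kernel names are hypothetical):

    auto* enc = cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
    enc->setComputePipelineState(producer_kernel);
    enc->setBuffer(buf, /* offset */ 0, /* index */ 0);
    enc->dispatchThreads(grid_dims, group_dims);

    // With concurrent dispatch the next kernel may start before the write
    // above lands, so a barrier on the written resource is required:
    MTL::Resource* res = buf; // MTL::Buffer* upcasts to MTL::Resource*
    enc->memoryBarrier(&res, 1);

    enc->setComputePipelineState(consumer_kernel);
    enc->setBuffer(buf, 0, 0);
    enc->dispatchThreads(grid_dims, group_dims);

The CommandEncoder wrapper introduced in the next file automates exactly this barrier placement.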
@@ -1,4 +1,4 @@
-// Copyright © 2023-24 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
@@ -7,10 +7,12 @@
 #include <mutex>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 
 #include <dlfcn.h>
 #include <filesystem>
 
+#include "mlx/array.h"
 #include "mlx/device.h"
 
 namespace fs = std::filesystem;
@@ -34,6 +36,69 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
 using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
+struct CommandEncoder {
+  CommandEncoder(MTL::ComputeCommandEncoder* enc)
+      : enc(enc), concurrent(false){};
+  CommandEncoder& operator=(const CommandEncoder&) = delete;
+
+  struct ConcurrentContext {
+    ConcurrentContext(CommandEncoder& enc) : enc(enc) {
+      enc.concurrent = true;
+    }
+    ~ConcurrentContext() {
+      enc.concurrent = false;
+      enc.outputs.insert(
+          enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
+      enc.concurrent_outputs.clear();
+    }
+
+   private:
+    CommandEncoder& enc;
+  };
+
+  MTL::ComputeCommandEncoder* operator->() {
+    return enc;
+  }
+
+  void set_input_array(const array& a, int idx, int offset = 0) {
+    auto r_buf =
+        static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
+    if (auto it = outputs.find(r_buf); it != outputs.end()) {
+      // Insert a barrier
+      enc->memoryBarrier(&r_buf, 1);
+
+      // Remove the output
+      outputs.erase(it);
+    }
+    auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+    auto base_offset = a.data<char>() -
+        static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+    base_offset += offset;
+    enc->setBuffer(a_buf, base_offset, idx);
+  }
+
+  void set_output_array(array& a, int idx, int offset = 0) {
+    // Add barriers before adding the output to the output set
+    set_input_array(a, idx, offset);
+    auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
+    if (concurrent) {
+      concurrent_outputs.insert(buf);
+    } else {
+      outputs.insert(buf);
+    }
+  }
+
+  ConcurrentContext start_concurrent() {
+    return ConcurrentContext(*this);
+  }
+
+ private:
+  MTL::ComputeCommandEncoder* enc;
+  bool concurrent;
+  std::unordered_set<MTL::Resource*> outputs;
+  std::unordered_set<MTL::Resource*> concurrent_outputs;
+};
+
 class Device {
  public:
   Device();
@@ -51,7 +116,7 @@ class Device {
   int get_command_buffer_ops(int index);
   void increment_command_buffer_ops(int index);
   void commit_command_buffer(int index);
-  MTL::ComputeCommandEncoder* get_command_encoder(int index);
+  CommandEncoder& get_command_encoder(int index);
   void end_encoding(int index);
 
   void register_library(
@@ -132,7 +197,7 @@ class Device {
   MTL::Device* device_;
   std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
   std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
-  std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
+  std::unordered_map<int32_t, CommandEncoder> encoder_map_;
   std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
   std::unordered_map<std::string, MTL::Library*> library_map_;
   std::mutex mtx_;
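The CommandEncoder struct above is the core of the change: set_output_array records each written buffer in `outputs`, and a later set_input_array on the same MTL::Resource encodes a memoryBarrier before the read and drops the buffer from the set, so barriers appear only on actual producer-to-consumer edges. A sketch of the intended call pattern (the arrays x, y, z and the kernels are hypothetical):

    auto& compute_encoder = d.get_command_encoder(s.index);

    compute_encoder->setComputePipelineState(kernel_a);
    compute_encoder.set_input_array(x, 0);
    compute_encoder.set_output_array(y, 1); // y's buffer joins the output set
    compute_encoder->dispatchThreads(grid_a, group_a);

    compute_encoder->setComputePipelineState(kernel_b);
    compute_encoder.set_input_array(y, 0);  // hit in the output set: a
                                            // memoryBarrier is encoded, then
                                            // y is removed from the set
    compute_encoder.set_output_array(z, 1);
    compute_encoder->dispatchThreads(grid_b, group_b);

Dispatches that never read each other's outputs encode no barriers at all, which is what MTL::DispatchTypeConcurrent is meant to exploit; operator-> still exposes the raw encoder for pipeline state, bytes, and dispatch calls.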
@@ -49,7 +49,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     kname << "_" << idx_ndim;
   }
 
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -81,8 +81,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Set all the buffers
-  set_array_buffer(compute_encoder, src, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(src, 0);
+  compute_encoder.set_output_array(out, 1);
 
   // Set source info
   compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
@@ -103,7 +103,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Set index buffers
   for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    compute_encoder.set_input_array(inputs[i], 20 + i);
   }
 
   // Launch grid
@@ -183,7 +183,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   kname << "_" << nidx;
 
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
 
   auto& upd = inputs.back();
@@ -192,8 +192,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setComputePipelineState(kernel);
 
   // Set all the buffers
-  set_array_buffer(compute_encoder, upd, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(upd, 1);
+  compute_encoder.set_output_array(out, 2);
 
   // Set update info
   uint upd_ndim = upd.ndim();
@@ -210,7 +210,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Set index buffers
   for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    compute_encoder.set_input_array(inputs[i], 20 + i);
   }
 
   // Launch grid
@@ -280,7 +280,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Set index buffers
   for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    compute_encoder.set_input_array(inputs[i], 20 + i);
   }
 
   // Launch grid
@@ -336,7 +336,7 @@ void steel_matmul(
        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
 
   // Encode and dispatch gemm kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -360,9 +360,9 @@ void steel_matmul(
    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
 
-   set_array_buffer(compute_encoder, a, 0);
-   set_array_buffer(compute_encoder, b, 1);
-   set_array_buffer(compute_encoder, C_split, 2);
+   compute_encoder.set_input_array(a, 0);
+   compute_encoder.set_input_array(b, 1);
+   compute_encoder.set_output_array(C_split, 2);
 
    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -380,8 +380,8 @@ void steel_matmul(
    compute_encoder->setComputePipelineState(kernel);
 
    // Set the arguments for the kernel
-   set_array_buffer(compute_encoder, C_split, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(C_split, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
    compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
    compute_encoder->setBytes(&N, sizeof(int), 4);
@@ -426,7 +426,7 @@ void steel_matmul(
        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
 
   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -467,9 +467,9 @@ void steel_matmul(
      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
 
   // Launch kernel
-  set_array_buffer(compute_encoder, a, 0);
-  set_array_buffer(compute_encoder, b, 1);
-  set_array_buffer(compute_encoder, out, 3);
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_output_array(out, 3);
 
   compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
 
@@ -622,7 +622,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    kname << "_nc" << !contiguous_kernel << "_axpby0";
 
    // Encode and dispatch kernel
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
@@ -630,9 +630,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size group_dims = MTL::Size(bn, bm, 1);
    MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
 
-   set_array_buffer(compute_encoder, mat, 0);
-   set_array_buffer(compute_encoder, vec, 1);
-   set_array_buffer(compute_encoder, out, 3);
+   compute_encoder.set_input_array(mat, 0);
+   compute_encoder.set_input_array(vec, 1);
+   compute_encoder.set_output_array(out, 3);
 
    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
@@ -834,7 +834,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    kname << "_nc" << !contiguous_kernel << "_axpby1";
 
    // Encode and dispatch kernel
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
@@ -842,10 +842,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size group_dims = MTL::Size(bn, bm, 1);
    MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
 
-   set_array_buffer(compute_encoder, mat, 0);
-   set_array_buffer(compute_encoder, vec, 1);
-   set_array_buffer(compute_encoder, c, 2);
-   set_array_buffer(compute_encoder, out, 3);
+   compute_encoder.set_input_array(mat, 0);
+   compute_encoder.set_input_array(vec, 1);
+   compute_encoder.set_input_array(c, 2);
+   compute_encoder.set_output_array(out, 3);
 
    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
@@ -907,7 +907,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
 
   // Encode and dispatch gemm kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -931,9 +931,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
 
-   set_array_buffer(compute_encoder, a, 0);
-   set_array_buffer(compute_encoder, b, 1);
-   set_array_buffer(compute_encoder, C_split, 2);
+   compute_encoder.set_input_array(a, 0);
+   compute_encoder.set_input_array(b, 1);
+   compute_encoder.set_output_array(C_split, 2);
 
    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -946,12 +946,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder->setComputePipelineState(kernel);
 
    // Set the arguments for the kernel
-   set_array_buffer(compute_encoder, C_split, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(C_split, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
    compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
    compute_encoder->setBytes(&N, sizeof(int), 4);
-   set_array_buffer(compute_encoder, c, 5);
+   compute_encoder.set_input_array(c, 5);
    compute_encoder->setBytes(&ldc, sizeof(int), 6);
    compute_encoder->setBytes(&fdc, sizeof(int), 7);
    compute_encoder->setBytes(&alpha_, sizeof(float), 8);
@@ -997,7 +997,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
        << ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
 
   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -1045,10 +1045,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
      batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
 
   // Launch kernel
-  set_array_buffer(compute_encoder, a, 0);
-  set_array_buffer(compute_encoder, b, 1);
-  set_array_buffer(compute_encoder, c, 2);
-  set_array_buffer(compute_encoder, out, 3);
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_input_array(c, 2);
+  compute_encoder.set_output_array(out, 3);
 
   compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
   compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
@@ -88,7 +88,6 @@ std::function<void()> make_task(
        if (!arr.is_tracer()) {
          arr.detach();
        }
 
        if (p) {
-         metal::device(s.device).end_encoding(s.index);
          scheduler::notify_new_task(s);
@@ -57,7 +57,7 @@ void RMSNorm::eval_gpu(
    op_name += "_looped";
  }
  op_name += type_to_name(out);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);
 
@@ -79,10 +79,10 @@ void RMSNorm::eval_gpu(
 
    uint32_t w_stride = w.strides()[0];
    compute_encoder->setComputePipelineState(kernel);
-   set_array_buffer(
-       compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
-   set_array_buffer(compute_encoder, w, 1);
-   set_array_buffer(compute_encoder, out, 2);
+   compute_encoder.set_input_array(
+       x.data_shared_ptr() == nullptr ? out : x, 0);
+   compute_encoder.set_input_array(w, 1);
+   compute_encoder.set_output_array(out, 2);
    compute_encoder->setBytes(&eps_, sizeof(float), 3);
    compute_encoder->setBytes(&axis_size, sizeof(int), 4);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
@@ -160,7 +160,7 @@ void RMSNormVJP::eval_gpu(
    op_name += "_looped";
  }
  op_name += type_to_name(gx);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);
 
@@ -182,12 +182,11 @@ void RMSNormVJP::eval_gpu(
 
  uint32_t w_stride = w.strides()[0];
  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
-  set_array_buffer(compute_encoder, w, 1);
-  set_array_buffer(
-      compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
-  set_array_buffer(compute_encoder, gx, 3);
-  set_array_buffer(compute_encoder, gw_temp, 4);
+  compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
+  compute_encoder.set_input_array(w, 1);
+  compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
+  compute_encoder.set_output_array(gx, 3);
+  compute_encoder.set_output_array(gw_temp, 4);
  compute_encoder->setBytes(&eps_, sizeof(float), 5);
  compute_encoder->setBytes(&axis_size, sizeof(int), 6);
  compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
@@ -251,7 +250,7 @@ void LayerNorm::eval_gpu(
    op_name += "_looped";
  }
  op_name += type_to_name(out);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);
 
@@ -274,11 +273,11 @@ void LayerNorm::eval_gpu(
    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
    uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
    compute_encoder->setComputePipelineState(kernel);
-   set_array_buffer(
-       compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
-   set_array_buffer(compute_encoder, w, 1);
-   set_array_buffer(compute_encoder, b, 2);
-   set_array_buffer(compute_encoder, out, 3);
+   compute_encoder.set_input_array(
+       x.data_shared_ptr() == nullptr ? out : x, 0);
+   compute_encoder.set_input_array(w, 1);
+   compute_encoder.set_input_array(b, 2);
+   compute_encoder.set_output_array(out, 3);
    compute_encoder->setBytes(&eps_, sizeof(float), 4);
    compute_encoder->setBytes(&axis_size, sizeof(int), 5);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
@@ -350,7 +349,7 @@ void LayerNormVJP::eval_gpu(
  }
 
  // Finish with the gradient for b in case we had a b
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  if (gb.ndim() == 1 && gb.size() == axis_size) {
    ReductionPlan plan(
        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
@@ -394,12 +393,11 @@ void LayerNormVJP::eval_gpu(
 
  uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
-  set_array_buffer(compute_encoder, w, 1);
-  set_array_buffer(
-      compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
-  set_array_buffer(compute_encoder, gx, 3);
-  set_array_buffer(compute_encoder, gw_temp, 4);
+  compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
+  compute_encoder.set_input_array(w, 1);
+  compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
+  compute_encoder.set_output_array(gx, 3);
+  compute_encoder.set_output_array(gw_temp, 4);
  compute_encoder->setBytes(&eps_, sizeof(float), 5);
  compute_encoder->setBytes(&axis_size, sizeof(int), 6);
  compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
@@ -68,18 +68,18 @@ void binary_op(
  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  // - If a is donated it goes to the first output
  // - If b is donated it goes to the first output if a was not donated
  //   otherwise it goes to the second output
  bool donate_a = a.data_shared_ptr() == nullptr;
  bool donate_b = b.data_shared_ptr() == nullptr;
-  set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
-  set_array_buffer(
-      compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
-  set_array_buffer(compute_encoder, outputs[0], 2);
-  set_array_buffer(compute_encoder, outputs[1], 3);
+  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
+  compute_encoder.set_input_array(
+      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
+  compute_encoder.set_output_array(outputs[0], 2);
+  compute_encoder.set_output_array(outputs[1], 3);
 
  if (bopt == BinaryOpType::General) {
    auto ndim = shape.size();
@@ -167,13 +167,13 @@ void binary_op(
  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  bool donate_a = a.data_shared_ptr() == nullptr;
  bool donate_b = b.data_shared_ptr() == nullptr;
-  set_array_buffer(compute_encoder, donate_a ? out : a, 0);
-  set_array_buffer(compute_encoder, donate_b ? out : b, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(donate_a ? out : a, 0);
+  compute_encoder.set_input_array(donate_b ? out : b, 1);
+  compute_encoder.set_output_array(out, 2);
 
  if (bopt == BinaryOpType::General) {
    auto ndim = shape.size();
@@ -253,12 +253,12 @@ void ternary_op(
  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, a, 0);
-  set_array_buffer(compute_encoder, b, 1);
-  set_array_buffer(compute_encoder, c, 2);
-  set_array_buffer(compute_encoder, out, 3);
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_input_array(c, 2);
+  compute_encoder.set_output_array(out, 3);
 
  if (topt == TernaryOpType::General) {
    auto ndim = shape.size();
@@ -339,11 +339,11 @@ void unary_op(
  }
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
 
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(
-      compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(
+      in.data_shared_ptr() == nullptr ? out : in, 0);
+  compute_encoder.set_output_array(out, 1);
  if (!contig) {
    compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
    compute_encoder->setBytes(
@@ -365,7 +365,7 @@ void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 template <typename T>
-void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
+void arange_set_scalars(T start, T next, CommandEncoder& enc) {
  enc->setBytes(&start, sizeof(T), 0);
  T step = next - start;
  enc->setBytes(&step, sizeof(T), 1);
@@ -384,7 +384,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size group_dims = MTL::Size(
      std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
 
  switch (out.dtype()) {
@@ -427,7 +427,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
      throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
  }
 
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_output_array(out, 2);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
 }
 
@@ -487,7 +487,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  // ArgReduce
  int simd_size = 32;
  int n_reads = 4;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name + type_to_name(in));
    NS::UInteger thread_group_size = std::min(
@@ -502,8 +502,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->setComputePipelineState(kernel);
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(out, 1);
    if (ndim == 0) {
      // Pass place holders so metal doesn't complain
      int shape_ = 0;
@@ -552,6 +552,9 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
  flags.row_contiguous = false;
  flags.col_contiguous = false;
  flags.contiguous = false;
+  auto& d = metal::device(stream().device);
+  auto& compute_encoder = d.get_command_encoder(stream().index);
+  auto concurrent_ctx = compute_encoder.start_concurrent();
  for (int i = 0; i < inputs.size(); i++) {
    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
    size_t data_offset = strides[axis_] * sizes[i];
@@ -791,10 +794,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
  MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, keys, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(keys, 0);
+  compute_encoder.set_output_array(out, 1);
  compute_encoder->setBytes(&odd, sizeof(bool), 2);
  compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
 
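The Concatenate hunk above is the first user of start_concurrent(): each slice copy writes a disjoint region of `out`, so while the RAII ConcurrentContext is alive the written buffers are parked in `concurrent_outputs` and no barriers are encoded between the copies. When the context is destroyed they are merged into `outputs`, so the next kernel that reads `out` still gets its barrier. A condensed sketch (the loop body and `slices` stand in for the real slice setup and copy calls):

    {
      auto concurrent_ctx = compute_encoder.start_concurrent();
      for (auto& out_slice : slices) {   // hypothetical disjoint slices of out
        encode_copy(compute_encoder, out_slice); // outputs deferred, no barrier
      }
    } // ~ConcurrentContext: deferred outputs merge into the tracked set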
@@ -48,7 +48,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
        << bits_ << "_fast";
 
    // Encode and dispatch kernel
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
@@ -57,11 +57,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size group_dims = MTL::Size(bd, 2, 1);
    MTL::Size grid_dims = MTL::Size(1, O / bo, B);
 
-   set_array_buffer(compute_encoder, w, 0);
-   set_array_buffer(compute_encoder, scales, 1);
-   set_array_buffer(compute_encoder, biases, 2);
-   set_array_buffer(compute_encoder, x, 3);
-   set_array_buffer(compute_encoder, out, 4);
+   compute_encoder.set_input_array(w, 0);
+   compute_encoder.set_input_array(scales, 1);
+   compute_encoder.set_input_array(biases, 2);
+   compute_encoder.set_input_array(x, 3);
+   compute_encoder.set_output_array(out, 4);
    compute_encoder->setBytes(&D, sizeof(int), 5);
    compute_encoder->setBytes(&O, sizeof(int), 6);
 
@@ -75,7 +75,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
        << bits_;
 
    // Encode and dispatch kernel
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
@@ -84,11 +84,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size group_dims = MTL::Size(bd, 2, 1);
    MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
 
-   set_array_buffer(compute_encoder, w, 0);
-   set_array_buffer(compute_encoder, scales, 1);
-   set_array_buffer(compute_encoder, biases, 2);
-   set_array_buffer(compute_encoder, x, 3);
-   set_array_buffer(compute_encoder, out, 4);
+   compute_encoder.set_input_array(w, 0);
+   compute_encoder.set_input_array(scales, 1);
+   compute_encoder.set_input_array(biases, 2);
+   compute_encoder.set_input_array(x, 3);
+   compute_encoder.set_output_array(out, 4);
    compute_encoder->setBytes(&D, sizeof(int), 5);
    compute_encoder->setBytes(&O, sizeof(int), 6);
 
@@ -102,7 +102,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
        << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
 
    // Encode and dispatch kernel
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
@@ -114,11 +114,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
 
-   set_array_buffer(compute_encoder, x, 0);
-   set_array_buffer(compute_encoder, w, 1);
-   set_array_buffer(compute_encoder, scales, 2);
-   set_array_buffer(compute_encoder, biases, 3);
-   set_array_buffer(compute_encoder, out, 4);
+   compute_encoder.set_input_array(x, 0);
+   compute_encoder.set_input_array(w, 1);
+   compute_encoder.set_input_array(scales, 2);
+   compute_encoder.set_input_array(biases, 3);
+   compute_encoder.set_output_array(out, 4);
    compute_encoder->setBytes(&B, sizeof(int), 5);
    compute_encoder->setBytes(&O, sizeof(int), 6);
    compute_encoder->setBytes(&D, sizeof(int), 7);
@@ -133,7 +133,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
        << bits_;
 
    // Encode and dispatch kernel
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
@@ -142,11 +142,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    MTL::Size group_dims = MTL::Size(bd, bo, 1);
    MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
 
-   set_array_buffer(compute_encoder, x, 0);
-   set_array_buffer(compute_encoder, w, 1);
-   set_array_buffer(compute_encoder, scales, 2);
-   set_array_buffer(compute_encoder, biases, 3);
-   set_array_buffer(compute_encoder, out, 4);
+   compute_encoder.set_input_array(x, 0);
+   compute_encoder.set_input_array(w, 1);
+   compute_encoder.set_input_array(scales, 2);
+   compute_encoder.set_input_array(biases, 3);
+   compute_encoder.set_output_array(out, 4);
    compute_encoder->setBytes(&D, sizeof(int), 5);
    compute_encoder->setBytes(&O, sizeof(int), 6);
 
@@ -160,7 +160,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
        << bits_;
 
    // Encode and dispatch kernel
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
@@ -179,11 +179,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      throw std::runtime_error(msg.str());
    }
 
-   set_array_buffer(compute_encoder, x, 0);
-   set_array_buffer(compute_encoder, w, 1);
-   set_array_buffer(compute_encoder, scales, 2);
-   set_array_buffer(compute_encoder, biases, 3);
-   set_array_buffer(compute_encoder, out, 4);
+   compute_encoder.set_input_array(x, 0);
+   compute_encoder.set_input_array(w, 1);
+   compute_encoder.set_input_array(scales, 2);
+   compute_encoder.set_input_array(biases, 3);
+   compute_encoder.set_output_array(out, 4);
    compute_encoder->setBytes(&B, sizeof(int), 5);
    compute_encoder->setBytes(&O, sizeof(int), 6);
    compute_encoder->setBytes(&D, sizeof(int), 7);
@@ -35,7 +35,7 @@ void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
-   MTL::ComputeCommandEncoder* compute_encoder,
+   CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  Dtype out_dtype = out.dtype();
@@ -71,8 +71,8 @@ void all_reduce_dispatch(
 
  // Encode buffers and dispatch
  if (is_out_64b_int == false || n_thread_groups == 1) {
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
 
@@ -85,14 +85,14 @@ void all_reduce_dispatch(
    std::vector<array> intermediates = {intermediate};
 
    // First dispatch
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, intermediate, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(intermediate, 1);
    compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
 
    // Second pass to reduce intermediate reduction results written to DRAM
-   set_array_buffer(compute_encoder, intermediate, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(intermediate, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
 
    mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
@@ -123,7 +123,7 @@ void row_reduce_general_dispatch(
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
-   MTL::ComputeCommandEncoder* compute_encoder,
+   CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  Dtype out_dtype = out.dtype();
@@ -208,8 +208,8 @@ void row_reduce_general_dispatch(
  // Dispatch kernel
  if (!is_out_64b_int || non_row_reductions == 1) {
    // Set the arguments for the kernel
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -230,8 +230,8 @@ void row_reduce_general_dispatch(
    std::vector<array> intermediates = {intermediate};
 
    // Set the arguments for the kernel
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, intermediate, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(intermediate, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -258,8 +258,8 @@ void row_reduce_general_dispatch(
    ndim = new_shape.size();
 
    // Set the arguments for the kernel
-   set_array_buffer(compute_encoder, intermediate, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(intermediate, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -301,7 +301,7 @@ void strided_reduce_general_dispatch(
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
-   MTL::ComputeCommandEncoder* compute_encoder,
+   CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  Dtype out_dtype = out.dtype();
@@ -349,8 +349,8 @@ void strided_reduce_general_dispatch(
  }
 
  // Encode arrays
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(out, 1);
  compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
  compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
  compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -415,8 +415,8 @@ void strided_reduce_general_dispatch(
 
  if (is_out_64b_int == false) {
    // Set the arguments for the kernel
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -450,8 +450,8 @@ void strided_reduce_general_dispatch(
    std::vector<array> intermediates = {intermediate};
 
    // Set the arguments for the kernel
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, intermediate, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(intermediate, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -494,8 +494,8 @@ void strided_reduce_general_dispatch(
        "row_reduce_general_no_atomics_" + op_name +
        type_to_name(intermediate));
    compute_encoder->setComputePipelineState(row_reduce_kernel);
-   set_array_buffer(compute_encoder, intermediate, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(intermediate, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 4);
@@ -573,7 +573,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
    // Initialize output
    auto& s = stream();
    auto& d = metal::device(s.device);
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    {
      auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
      size_t nthreads = out.size();
@@ -584,7 +584,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
      }
      MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
      compute_encoder->setComputePipelineState(kernel);
-     set_array_buffer(compute_encoder, out, 0);
+     compute_encoder.set_output_array(out, 0);
      compute_encoder->dispatchThreads(grid_dims, group_dims);
    }
 
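The two-pass reductions above show where the tracked-output scheme pays off: the first dispatch registers `intermediate` via set_output_array, and when the second pass binds the same buffer with set_input_array, the encoder finds it in its output set and encodes the memoryBarrier automatically; no reduction code mentions barriers. A condensed sketch of that flow (grid and group sizes elided):

    // First pass: partial reduction into intermediate
    compute_encoder.set_input_array(in, 0);
    compute_encoder.set_output_array(intermediate, 1); // recorded as pending
    compute_encoder->dispatchThreads(grid_dims, group_dims);

    // Second pass: intermediate is found in the pending-output set, so a
    // barrier on its MTL::Resource is encoded before the read
    compute_encoder.set_input_array(intermediate, 0);
    compute_encoder.set_output_array(out, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);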
@@ -8,11 +8,13 @@
 
 namespace mlx::core {
 
+using metal::CommandEncoder;
+
 void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
-   MTL::ComputeCommandEncoder* compute_encoder,
+   CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);
 
@@ -22,7 +24,7 @@ void row_reduce_general_dispatch(
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
-   MTL::ComputeCommandEncoder* compute_encoder,
+   CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);
 
@@ -32,7 +34,7 @@ void strided_reduce_general_dispatch(
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
-   MTL::ComputeCommandEncoder* compute_encoder,
+   CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);
 
@@ -66,12 +66,12 @@ void RoPE::eval_gpu(
  kname << "rope_" << (forward_ ? "" : "vjp_")
        << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
 
  float base = std::log2(base_);
  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, donated ? out : in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(donated ? out : in, 0);
+  compute_encoder.set_output_array(out, 1);
  compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
  compute_encoder->setBytes(&offset_, sizeof(int), 4);
@@ -71,7 +71,7 @@ void sdpa_metal(
 
  std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
  kname_partials << kname_suffix;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname_partials.str());
  compute_encoder->setComputePipelineState(kernel);
 
@@ -87,15 +87,15 @@ void sdpa_metal(
  MLXScaledDotProductAttentionParams params{
      query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
 
-  set_array_buffer(compute_encoder, q, 0);
-  set_array_buffer(compute_encoder, k, 1);
-  set_array_buffer(compute_encoder, v, 2);
+  compute_encoder.set_input_array(q, 0);
+  compute_encoder.set_input_array(k, 1);
+  compute_encoder.set_input_array(v, 2);
  compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
  compute_encoder->setBytes(
      &params, sizeof(MLXScaledDotProductAttentionParams), 4);
-  set_array_buffer(compute_encoder, o_partial, 5);
-  set_array_buffer(compute_encoder, p_lse, 6);
-  set_array_buffer(compute_encoder, p_rowmaxes, 7);
+  compute_encoder.set_input_array(o_partial, 5);
+  compute_encoder.set_input_array(p_lse, 6);
+  compute_encoder.set_input_array(p_rowmaxes, 7);
 
  constexpr const uint tgroupMemorySize = 32768;
  compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
@@ -104,12 +104,12 @@ void sdpa_metal(
  {
    auto kernel_accum = d.get_kernel(kname_reduce.str());
    compute_encoder->setComputePipelineState(kernel_accum);
-   set_array_buffer(compute_encoder, o_partial, 0);
-   set_array_buffer(compute_encoder, p_lse, 1);
-   set_array_buffer(compute_encoder, p_rowmaxes, 2);
+   compute_encoder.set_input_array(o_partial, 0);
+   compute_encoder.set_input_array(p_lse, 1);
+   compute_encoder.set_input_array(p_rowmaxes, 2);
    compute_encoder->setBytes(
        &params, sizeof(MLXScaledDotProductAttentionParams), 3);
-   set_array_buffer(compute_encoder, out, 4);
+   compute_encoder.set_output_array(out, 4);
 
    MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
    MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
@@ -52,10 +52,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
    kname << type_to_name(in) << "_" << type_to_name(out);
 
    auto kernel = d.get_kernel(kname.str());
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder->setComputePipelineState(kernel);
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(out, 1);
    size_t size = in.shape(axis_);
    compute_encoder->setBytes(&size, sizeof(size_t), 2);
 
@@ -101,10 +101,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
    kname << type_to_name(in) << "_" << type_to_name(out);
 
    auto kernel = d.get_kernel(kname.str());
-   auto compute_encoder = d.get_command_encoder(s.index);
+   auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder->setComputePipelineState(kernel);
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(out, 1);
    size_t size = in.shape(axis_);
    size_t stride = in.strides()[axis_];
    compute_encoder->setBytes(&size, sizeof(size_t), 2);
@@ -60,7 +60,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
    op_name += "precise_";
  }
  op_name += type_to_name(out);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);
 
@@ -81,9 +81,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
    }
 
    compute_encoder->setComputePipelineState(kernel);
-   set_array_buffer(
-       compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
-   set_array_buffer(compute_encoder, out, 1);
+   compute_encoder.set_input_array(
+       in.data_shared_ptr() == nullptr ? out : in, 0);
+   compute_encoder.set_output_array(out, 1);
    compute_encoder->setBytes(&axis_size, sizeof(int), 2);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
@@ -57,13 +57,13 @@ void single_block_sort(
  }
 
  // Prepare command encoder
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);
 
  // Set inputs
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(out, 1);
  compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
  compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
 
@@ -131,7 +131,7 @@ void multi_block_sort(
      dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
 
  // Prepare command encoder
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
 
  // Do blockwise sort
  {
@@ -142,9 +142,9 @@ void multi_block_sort(
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
-   set_array_buffer(compute_encoder, in, 0);
-   set_array_buffer(compute_encoder, dev_vals_0, 1);
-   set_array_buffer(compute_encoder, dev_idxs_0, 2);
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_output_array(dev_vals_0, 1);
+   compute_encoder.set_output_array(dev_idxs_0, 2);
    compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
    compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
    compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
@@ -181,9 +181,9 @@ void multi_block_sort(
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
-   set_array_buffer(compute_encoder, block_partitions, 0);
-   set_array_buffer(compute_encoder, dev_vals_in, 1);
-   set_array_buffer(compute_encoder, dev_idxs_in, 2);
+   compute_encoder.set_output_array(block_partitions, 0);
+   compute_encoder.set_input_array(dev_vals_in, 1);
+   compute_encoder.set_input_array(dev_idxs_in, 2);
    compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
    compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
 
@@ -202,11 +202,11 @@ void multi_block_sort(
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);
 
-   set_array_buffer(compute_encoder, block_partitions, 0);
-   set_array_buffer(compute_encoder, dev_vals_in, 1);
-   set_array_buffer(compute_encoder, dev_idxs_in, 2);
-   set_array_buffer(compute_encoder, dev_vals_out, 3);
-   set_array_buffer(compute_encoder, dev_idxs_out, 4);
+   compute_encoder.set_input_array(block_partitions, 0);
+   compute_encoder.set_input_array(dev_vals_in, 1);
+   compute_encoder.set_input_array(dev_idxs_in, 2);
+   compute_encoder.set_output_array(dev_vals_out, 3);
+   compute_encoder.set_output_array(dev_idxs_out, 4);
    compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
    compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
    compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
@@ -10,29 +10,11 @@ namespace mlx::core {
 
 namespace {
 
-inline void
-set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
-  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto offset = a.data<char>() -
-      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-  enc->setBuffer(a_buf, offset, idx);
-}
-
-inline void set_array_buffer(
-    MTL::ComputeCommandEncoder* enc,
-    const array& a,
-    int64_t offset,
-    int idx) {
-  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto base_offset = a.data<char>() -
-      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-  base_offset += offset;
-  enc->setBuffer(a_buf, base_offset, idx);
-}
+using metal::CommandEncoder;
 
 template <typename T>
 inline void set_vector_bytes(
-    MTL::ComputeCommandEncoder* enc,
+    CommandEncoder& enc,
    const std::vector<T>& vec,
    size_t nelems,
    int idx) {
@@ -40,10 +22,8 @@ inline void set_vector_bytes(
 }
 
 template <typename T>
-inline void set_vector_bytes(
-    MTL::ComputeCommandEncoder* enc,
-    const std::vector<T>& vec,
-    int idx) {
+inline void
+set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
  return set_vector_bytes(enc, vec, vec.size(), idx);
 }
 
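With set_array_buffer gone from this header, array binding lives entirely in metal::CommandEncoder and only the byte-setting helpers remain, now taking the wrapper. A usage sketch for the updated set_vector_bytes, assuming its body still forwards to setBytes as before (the strides values are hypothetical):

    std::vector<size_t> strides = {16, 4, 1};
    set_vector_bytes(compute_encoder, strides, 2);    // all 3 elements at index 2
    set_vector_bytes(compute_encoder, strides, 2, 3); // first 2 elements at index 3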
mlx/ops.cpp
@@ -1050,19 +1050,30 @@ array transpose(
  for (auto& ax : axes) {
    ax = ax < 0 ? ax + a.ndim() : ax;
  }
-  std::set dims(axes.begin(), axes.end());
-  if (dims.size() != axes.size()) {
-    throw std::invalid_argument("Repeat axes not allowed in transpose.");
-  }
-  if (dims.size() != a.ndim() ||
-      a.ndim() > 0 &&
-          (*dims.begin() != 0 || *dims.rbegin() != (a.ndim() - 1))) {
-    throw std::invalid_argument("Transpose axes don't match array dimensions.");
-  }
-  std::vector<int> shape;
-  shape.reserve(axes.size());
-  for (auto ax : axes) {
-    shape.push_back(a.shape()[ax]);
-  }
+  if (axes.size() != a.ndim()) {
+    std::ostringstream msg;
+    msg << "[transpose] Recived " << axes.size() << " axes for array with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Check in bounds and for duplicates
+  std::vector<int> shape(axes.size(), 0);
+  for (auto& ax : axes) {
+    if (ax < 0 || ax >= a.ndim()) {
+      std::ostringstream msg;
+      msg << "[transpose] Invalid axis (" << ax << ") for array with "
+          << a.ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if (shape[ax] != 0) {
+      throw std::invalid_argument("[transpose] Repeat axes not allowed.");
+    }
+    shape[ax] = 1;
+  }
+
+  for (int i = 0; i < axes.size(); ++i) {
+    shape[i] = a.shape()[axes[i]];
+  }
  return array(
      std::move(shape),
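The rewritten transpose validation doubles `shape` as a scratch seen-table: it is zero-initialized, each valid axis slot is flipped to 1, and a slot that already holds 1 flags a repeated axis; the real output sizes then overwrite the markers. A standalone sketch of just the duplicate check:

    std::vector<int> axes = {1, 0, 1};      // axis 1 repeated
    std::vector<int> shape(axes.size(), 0); // doubles as a "seen" table
    for (auto ax : axes) {
      if (shape[ax] != 0) {
        throw std::invalid_argument("[transpose] Repeat axes not allowed.");
      }
      shape[ax] = 1;                        // mark axis as seen
    }
    // a valid permutation reaches here; shape is then filled with real sizes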
@@ -33,10 +33,10 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
 }
 
 inline array to_array_with_accessor(nb::object obj) {
-  if (nb::hasattr(obj, "__mlx_array__")) {
-    return nb::cast<array>(obj.attr("__mlx_array__")());
-  } else if (nb::isinstance<array>(obj)) {
+  if (nb::isinstance<array>(obj)) {
    return nb::cast<array>(obj);
+  } else if (nb::hasattr(obj, "__mlx_array__")) {
+    return nb::cast<array>(obj.attr("__mlx_array__")());
  } else {
    std::ostringstream msg;
    msg << "Invalid type " << nb::type_name(obj.type()).c_str()