mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP (metal)
This commit is contained in:
@@ -109,7 +109,7 @@ inline void build_kernel(
|
||||
|
||||
// Read constant / contiguous inputs in tmps
|
||||
std::vector<array> nc_inputs;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(inputs); ++i) {
|
||||
auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
@@ -134,7 +134,7 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Initialize the indices for non-contiguous inputs
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(nc_inputs); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
||||
if (ndim == 1) {
|
||||
@@ -174,7 +174,7 @@ inline void build_kernel(
|
||||
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
||||
}
|
||||
os += " uint l = zpos % output_shape[d];\n";
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(nc_inputs); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" index_{0} += ", xname);
|
||||
if (dynamic_dims) {
|
||||
@@ -195,7 +195,7 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read non-contiguous inputs into tmps
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(nc_inputs); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
os += fmt::format(
|
||||
@@ -214,7 +214,7 @@ inline void build_kernel(
|
||||
} else {
|
||||
os += x.primitive().name();
|
||||
os += "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
|
||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||
}
|
||||
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
||||
@@ -227,7 +227,7 @@ inline void build_kernel(
|
||||
}
|
||||
// Increment indices and close per thread loop
|
||||
if (work_per_thread > 1) {
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(nc_inputs); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
if (!dynamic_dims) {
|
||||
@@ -396,7 +396,7 @@ void Compiled::eval_gpu(
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
Strides in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(inputs); i++) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -990,7 +990,7 @@ void conv_3D_gpu(
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
std::vector<array>& /* copies */) {
|
||||
// Make conv params
|
||||
MLXConvParams<3> conv_params{
|
||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||
|
||||
@@ -68,7 +68,7 @@ std::string write_signature(
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(inputs); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
@@ -109,7 +109,7 @@ std::string write_signature(
|
||||
}
|
||||
}
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(output_names); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " device ";
|
||||
@@ -126,8 +126,8 @@ std::string write_signature(
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]]";
|
||||
if (index < inputs.size() + output_names.size() - 1 ||
|
||||
attributes.size() > 0) {
|
||||
if (index < std::ssize(inputs) + std::ssize(output_names) - 1 ||
|
||||
std::ssize(attributes) > 0) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
@@ -138,7 +138,7 @@ std::string write_signature(
|
||||
index = 0;
|
||||
for (const auto& attr : attributes) {
|
||||
kernel_source += attr;
|
||||
if (index < attributes.size() - 1) {
|
||||
if (index < std::ssize(attributes) - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
@@ -381,7 +381,7 @@ void CustomKernel::eval_gpu(
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
int index = 0;
|
||||
for (int i = 0; i < checked_inputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(checked_inputs); i++) {
|
||||
const array& in = checked_inputs[i];
|
||||
auto& shape_info = shape_infos_[i];
|
||||
compute_encoder.set_input_array(in, index);
|
||||
@@ -408,7 +408,7 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
auto tg_size = tx * ty * tz;
|
||||
unsigned long tg_size = tx * ty * tz;
|
||||
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (tg_size > max_tg_size) {
|
||||
std::ostringstream msg;
|
||||
|
||||
@@ -127,6 +127,9 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
(void)device;
|
||||
(void)lib_name;
|
||||
#endif
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
@@ -713,7 +716,7 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
|
||||
|
||||
std::vector<NS::Object*> objs(funcs.size());
|
||||
for (int i = 0; i < funcs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(funcs); i++) {
|
||||
objs[i] = funcs[i];
|
||||
}
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ struct DeviceStream {
|
||||
// Data updated between command buffers
|
||||
MTL::CommandBuffer* buffer{nullptr};
|
||||
int buffer_ops{0};
|
||||
size_t buffer_sizes{0};
|
||||
int64_t buffer_sizes{0};
|
||||
|
||||
// The command encoder, fence, and temporaries are updated between command
|
||||
// encoders
|
||||
|
||||
@@ -76,7 +76,7 @@ void Fence::wait(Stream stream, const array& x) {
|
||||
auto command_buffer = d.get_command_buffer(idx);
|
||||
command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count);
|
||||
command_buffer->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ void Fence::wait(Stream stream, const array& x) {
|
||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||
|
||||
d.get_command_buffer(idx)->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
|
||||
}
|
||||
|
||||
void Fence::update(Stream stream, const array& x) {
|
||||
@@ -124,7 +124,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(f.fence), f.count);
|
||||
command_buffer->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -154,7 +154,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||
|
||||
d.get_command_buffer(idx)->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -60,7 +60,7 @@ struct FourStepParams {
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
int64_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
@@ -93,7 +93,7 @@ std::vector<int> plan_stockham_fft(int n) {
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(radices); i++) {
|
||||
int radix = radices[i];
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
@@ -181,7 +181,7 @@ int compute_elems_per_thread(FFTPlan plan) {
|
||||
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
|
||||
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(steps); i++) {
|
||||
int radix = radices[i % radices.size()];
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radix);
|
||||
@@ -260,7 +260,7 @@ int primitive_root(int n) {
|
||||
|
||||
std::tuple<array, array, array> compute_raders_constants(
|
||||
int rader_n,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
int proot = primitive_root(rader_n);
|
||||
// Fermat's little theorem
|
||||
int inv = mod_exp(proot, rader_n - 2, rader_n);
|
||||
@@ -508,7 +508,7 @@ void four_step_fft(
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
int64_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
@@ -612,11 +612,11 @@ void fft_op(
|
||||
|
||||
// Start of radix/rader step constants
|
||||
int index = 4;
|
||||
for (int i = 0; i < plan.stockham.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(plan.stockham); i++) {
|
||||
func_consts.push_back(make_int(&plan.stockham[i], index));
|
||||
index += 1;
|
||||
}
|
||||
for (int i = 0; i < plan.rader.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(plan.rader); i++) {
|
||||
func_consts.push_back(make_int(&plan.rader[i], index));
|
||||
index += 1;
|
||||
}
|
||||
@@ -771,8 +771,8 @@ void nd_fft_op(
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
for (int i = std::ssize(axes) - 1; i >= 0; i--) {
|
||||
int reverse_index = std::ssize(axes) - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
@@ -780,8 +780,8 @@ void nd_fft_op(
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
bool step_real = (real && index == std::ssize(axes) - 1);
|
||||
const array& in_arr = i == std::ssize(axes) - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ std::string gen_hadamard_codelet(int m) {
|
||||
while (end != std::string_view::npos) {
|
||||
source << " tmp[" << index << "] = ";
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
for (int i = 0; i < std::ssize(row); i++) {
|
||||
source << " " << row[i] << " x[" << i << "]";
|
||||
}
|
||||
source << ";" << std::endl;
|
||||
|
||||
@@ -52,7 +52,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t slice_size = 1;
|
||||
int64_t slice_size = 1;
|
||||
for (auto s : slice_sizes_) {
|
||||
slice_size *= s;
|
||||
}
|
||||
@@ -94,8 +94,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
|
||||
size_t dim_y = indices.size();
|
||||
int64_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
|
||||
int64_t dim_y = indices.size();
|
||||
auto group_dims = get_block_dims(dim_x, dim_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
|
||||
|
||||
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
int64_t ndim = src.ndim();
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"gather{0}{1}_{2}_{3}_{4}",
|
||||
@@ -149,8 +149,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Launch 3D grid of threads
|
||||
// First two dimensions for the indices, the last one for the slice
|
||||
size_t dim0 = 1;
|
||||
size_t dim1 = 1;
|
||||
int64_t dim0 = 1;
|
||||
int64_t dim1 = 1;
|
||||
if (nidx) {
|
||||
if (inputs[1].ndim() >= 1) {
|
||||
dim0 = inputs[1].shape(0);
|
||||
@@ -159,13 +159,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dim1 = inputs[1].size() / dim0;
|
||||
}
|
||||
}
|
||||
size_t dim2 = slice_size;
|
||||
int64_t dim2 = slice_size;
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
std::vector<int64_t> idx_strides;
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t idx_size = nidx ? inputs[1].size() : 1;
|
||||
int64_t idx_size = nidx ? inputs[1].size() : 1;
|
||||
|
||||
auto idx_to_out = idx_size / out.size();
|
||||
int nwork;
|
||||
@@ -345,7 +345,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
|
||||
size_t nthreads = upd.size();
|
||||
int64_t nthreads = upd.size();
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -354,8 +354,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set update info
|
||||
size_t upd_ndim = upd.ndim();
|
||||
size_t upd_size = 1;
|
||||
int64_t upd_ndim = upd.ndim();
|
||||
int64_t upd_size = 1;
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
@@ -391,7 +391,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_bytes(upd_size, 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
int64_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
@@ -448,7 +448,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
int64_t ndim = src.ndim();
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
@@ -486,8 +486,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Grid [size post, index size, size pre]
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
int64_t size_pre = 1;
|
||||
int64_t size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
@@ -541,7 +541,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
int64_t ndim = src.ndim();
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
@@ -602,8 +602,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Grid [size post, index size, size pre]
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
int64_t size_pre = 1;
|
||||
int64_t size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
|
||||
@@ -344,7 +344,7 @@ void steel_gemm_splitk_axpby(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int /* batch_size_out */,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
|
||||
@@ -179,8 +179,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&,
|
||||
const std::optional<array>& mask_out,
|
||||
const std::optional<array>& mask_op,
|
||||
const std::optional<array>& /* mask_out */,
|
||||
const std::optional<array>& /* mask_op */,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
|
||||
@@ -134,7 +134,7 @@ void RMSNormVJP::eval_gpu(
|
||||
d.add_temporary(g, s.index);
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
auto axis_size = x.shape().back();
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate the gradient accumulator gw and a temporary to store the
|
||||
@@ -246,7 +246,7 @@ void LayerNorm::eval_gpu(
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
auto axis_size = x.shape().back();
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
int simd_size = 32;
|
||||
@@ -344,7 +344,7 @@ void LayerNormVJP::eval_gpu(
|
||||
d.add_temporary(g, s.index);
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
auto axis_size = x.shape().back();
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
|
||||
@@ -152,7 +152,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Load::eval_gpu(const std::vector<array>& /* inputs */, array& /* out */) {
|
||||
throw std::runtime_error("[Load::eval_gpu] Not implemented.");
|
||||
}
|
||||
|
||||
@@ -201,41 +201,45 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void QRF::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
|
||||
}
|
||||
|
||||
void SVD::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
|
||||
}
|
||||
|
||||
void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
|
||||
void Inverse::eval_gpu(
|
||||
const std::vector<array>& /* inputs */,
|
||||
array& /* output */) {
|
||||
throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
|
||||
}
|
||||
|
||||
void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Cholesky::eval_gpu(
|
||||
const std::vector<array>& /* inputs */,
|
||||
array& /* out */) {
|
||||
throw std::runtime_error(
|
||||
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
|
||||
}
|
||||
|
||||
void Eig::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
|
||||
}
|
||||
|
||||
void Eigh::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
|
||||
}
|
||||
|
||||
void LUF::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<array>& /* inputs */,
|
||||
std::vector<array>& /* outputs */) {
|
||||
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
|
||||
}
|
||||
|
||||
|
||||
@@ -291,7 +291,7 @@ void init_reduce(
|
||||
const std::string& op_name,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
auto [_, out_type] = remap_reduce_types(out, op_name);
|
||||
const std::string func_name = "init_reduce";
|
||||
std::string kname = func_name;
|
||||
@@ -397,7 +397,7 @@ void row_reduce_small(
|
||||
RowReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
// Set the kernel
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
@@ -453,7 +453,7 @@ void row_reduce_simple(
|
||||
RowReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
// Set the kernel
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
const std::string func_name = "row_reduce_simple";
|
||||
@@ -493,7 +493,7 @@ void row_reduce_looped(
|
||||
RowReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
|
||||
// Set the kernel
|
||||
@@ -570,7 +570,7 @@ void strided_reduce_small(
|
||||
ColReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
|
||||
// Figure out the grid dims
|
||||
@@ -747,7 +747,7 @@ void strided_reduce_looped(
|
||||
ColReduceArgs& args,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const Stream& /* s */) {
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
@@ -959,7 +959,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Continue with reduction operation
|
||||
// Minimum of 4 bytes since we use size 4 structs for all reduce
|
||||
// and metal will complain o/w
|
||||
size_t min_bytes = std::max(out.nbytes(), 4ul);
|
||||
size_t min_bytes = std::max<int64_t>(out.nbytes(), 4);
|
||||
out.set_data(allocator::malloc(min_bytes));
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
|
||||
@@ -80,7 +80,7 @@ void ResidencySet::resize(size_t size) {
|
||||
// Remove wired allocations until under capacity
|
||||
auto allocations = wired_set_->allAllocations();
|
||||
auto num_allocations = wired_set_->allocationCount();
|
||||
for (int i = 0; i < num_allocations && current_size > size; ++i) {
|
||||
for (size_t i = 0; i < num_allocations && current_size > size; ++i) {
|
||||
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
|
||||
wired_set_->removeAllocation(buf);
|
||||
current_size -= buf->allocatedSize();
|
||||
|
||||
@@ -33,7 +33,7 @@ void concatenate_gpu(
|
||||
auto& d = metal::device(s.device);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto concurrent_ctx = compute_encoder.start_concurrent();
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(inputs); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
|
||||
@@ -29,6 +29,10 @@ inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
|
||||
std::ostringstream label;
|
||||
label << "Stream " << index;
|
||||
queue->setLabel(make_string(label));
|
||||
#else
|
||||
// appease warnings
|
||||
(void)queue;
|
||||
(void)index;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -42,6 +46,9 @@ inline void debug_set_primitive_buffer_label(
|
||||
}
|
||||
label << primitive.name();
|
||||
command_buffer->setLabel(make_string(label));
|
||||
#else
|
||||
(void)command_buffer;
|
||||
(void)primitive;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user