WIP (metal)

This commit is contained in:
Ronan Collobert
2025-10-31 09:43:29 -07:00
parent 981d2fdaf0
commit 979abf462b
17 changed files with 94 additions and 80 deletions

View File

@@ -109,7 +109,7 @@ inline void build_kernel(
// Read constant / contiguous inputs in tmps // Read constant / contiguous inputs in tmps
std::vector<array> nc_inputs; std::vector<array> nc_inputs;
for (int i = 0; i < inputs.size(); ++i) { for (int i = 0; i < std::ssize(inputs); ++i) {
auto& x = inputs[i]; auto& x = inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
@@ -134,7 +134,7 @@ inline void build_kernel(
} }
// Initialize the indices for non-contiguous inputs // Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < std::ssize(nc_inputs); ++i) {
auto& xname = namer.get_name(nc_inputs[i]); auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname); os += fmt::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) { if (ndim == 1) {
@@ -174,7 +174,7 @@ inline void build_kernel(
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3); os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
} }
os += " uint l = zpos % output_shape[d];\n"; os += " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < std::ssize(nc_inputs); ++i) {
auto& xname = namer.get_name(nc_inputs[i]); auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" index_{0} += ", xname); os += fmt::format(" index_{0} += ", xname);
if (dynamic_dims) { if (dynamic_dims) {
@@ -195,7 +195,7 @@ inline void build_kernel(
} }
// Read non-contiguous inputs into tmps // Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < std::ssize(nc_inputs); ++i) {
auto& x = nc_inputs[i]; auto& x = nc_inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
os += fmt::format( os += fmt::format(
@@ -214,7 +214,7 @@ inline void build_kernel(
} else { } else {
os += x.primitive().name(); os += x.primitive().name();
os += "()("; os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
} }
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back())); os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
@@ -227,7 +227,7 @@ inline void build_kernel(
} }
// Increment indices and close per thread loop // Increment indices and close per thread loop
if (work_per_thread > 1) { if (work_per_thread > 1) {
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < std::ssize(nc_inputs); ++i) {
auto& x = nc_inputs[i]; auto& x = nc_inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
if (!dynamic_dims) { if (!dynamic_dims) {
@@ -396,7 +396,7 @@ void Compiled::eval_gpu(
int cnt = 0; int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides int stride_idx = 1; // idx 0 is the output strides
Strides in_strides; Strides in_strides;
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < std::ssize(inputs); i++) {
if (is_constant_(i)) { if (is_constant_(i)) {
continue; continue;
} }

View File

@@ -990,7 +990,7 @@ void conv_3D_gpu(
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
bool flip, bool flip,
std::vector<array>& copies) { std::vector<array>& /* copies */) {
// Make conv params // Make conv params
MLXConvParams<3> conv_params{ MLXConvParams<3> conv_params{
/* const int N = */ static_cast<int>(in.shape(0)), /* const int N = */ static_cast<int>(in.shape(0)),

View File

@@ -68,7 +68,7 @@ std::string write_signature(
int index = 0; int index = 0;
constexpr int max_constant_array_size = 8; constexpr int max_constant_array_size = 8;
// Add inputs // Add inputs
for (int i = 0; i < inputs.size(); ++i) { for (int i = 0; i < std::ssize(inputs); ++i) {
const auto& name = input_names[i]; const auto& name = input_names[i];
const auto& arr = inputs[i]; const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype()); auto dtype = get_type_string(arr.dtype());
@@ -109,7 +109,7 @@ std::string write_signature(
} }
} }
// Add outputs // Add outputs
for (int i = 0; i < output_names.size(); ++i) { for (int i = 0; i < std::ssize(output_names); ++i) {
const auto& name = output_names[i]; const auto& name = output_names[i];
const auto& dtype = output_dtypes[i]; const auto& dtype = output_dtypes[i];
kernel_source += " device "; kernel_source += " device ";
@@ -126,8 +126,8 @@ std::string write_signature(
kernel_source += " [[buffer("; kernel_source += " [[buffer(";
kernel_source += std::to_string(index); kernel_source += std::to_string(index);
kernel_source += ")]]"; kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 || if (index < std::ssize(inputs) + std::ssize(output_names) - 1 ||
attributes.size() > 0) { std::ssize(attributes) > 0) {
kernel_source += ",\n"; kernel_source += ",\n";
} else { } else {
kernel_source += ") {\n"; kernel_source += ") {\n";
@@ -138,7 +138,7 @@ std::string write_signature(
index = 0; index = 0;
for (const auto& attr : attributes) { for (const auto& attr : attributes) {
kernel_source += attr; kernel_source += attr;
if (index < attributes.size() - 1) { if (index < std::ssize(attributes) - 1) {
kernel_source += ",\n"; kernel_source += ",\n";
} else { } else {
kernel_source += ") {\n"; kernel_source += ") {\n";
@@ -381,7 +381,7 @@ void CustomKernel::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int index = 0; int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) { for (int i = 0; i < std::ssize(checked_inputs); i++) {
const array& in = checked_inputs[i]; const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i]; auto& shape_info = shape_infos_[i];
compute_encoder.set_input_array(in, index); compute_encoder.set_input_array(in, index);
@@ -408,7 +408,7 @@ void CustomKernel::eval_gpu(
} }
const auto [tx, ty, tz] = threadgroup_; const auto [tx, ty, tz] = threadgroup_;
auto tg_size = tx * ty * tz; unsigned long tg_size = tx * ty * tz;
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup(); auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
if (tg_size > max_tg_size) { if (tg_size > max_tg_size) {
std::ostringstream msg; std::ostringstream msg;

View File

@@ -127,6 +127,9 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
} }
} }
} }
#else
(void)device;
(void)lib_name;
#endif #endif
return {nullptr, nullptr}; return {nullptr, nullptr};
} }
@@ -713,7 +716,7 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
auto lfuncs = MTL::LinkedFunctions::linkedFunctions(); auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
std::vector<NS::Object*> objs(funcs.size()); std::vector<NS::Object*> objs(funcs.size());
for (int i = 0; i < funcs.size(); i++) { for (int i = 0; i < std::ssize(funcs); i++) {
objs[i] = funcs[i]; objs[i] = funcs[i];
} }

View File

@@ -137,7 +137,7 @@ struct DeviceStream {
// Data updated between command buffers // Data updated between command buffers
MTL::CommandBuffer* buffer{nullptr}; MTL::CommandBuffer* buffer{nullptr};
int buffer_ops{0}; int buffer_ops{0};
size_t buffer_sizes{0}; int64_t buffer_sizes{0};
// The command encoder, fence, and temporaries are updated between command // The command encoder, fence, and temporaries are updated between command
// encoders // encoders

View File

@@ -76,7 +76,7 @@ void Fence::wait(Stream stream, const array& x) {
auto command_buffer = d.get_command_buffer(idx); auto command_buffer = d.get_command_buffer(idx);
command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count); command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count);
command_buffer->addCompletedHandler( command_buffer->addCompletedHandler(
[fence_ = fence_](MTL::CommandBuffer* cbuf) {}); [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
return; return;
} }
@@ -96,7 +96,7 @@ void Fence::wait(Stream stream, const array& x) {
compute_encoder.dispatch_threads(kernel_dims, kernel_dims); compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
d.get_command_buffer(idx)->addCompletedHandler( d.get_command_buffer(idx)->addCompletedHandler(
[fence_ = fence_](MTL::CommandBuffer* cbuf) {}); [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
} }
void Fence::update(Stream stream, const array& x) { void Fence::update(Stream stream, const array& x) {
@@ -124,7 +124,7 @@ void Fence::update(Stream stream, const array& x) {
command_buffer->encodeSignalEvent( command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(f.fence), f.count); static_cast<MTL::Event*>(f.fence), f.count);
command_buffer->addCompletedHandler( command_buffer->addCompletedHandler(
[fence_ = fence_](MTL::CommandBuffer* cbuf) {}); [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
return; return;
} }
@@ -154,7 +154,7 @@ void Fence::update(Stream stream, const array& x) {
compute_encoder.dispatch_threads(kernel_dims, kernel_dims); compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
d.get_command_buffer(idx)->addCompletedHandler( d.get_command_buffer(idx)->addCompletedHandler(
[fence_ = fence_](MTL::CommandBuffer* cbuf) {}); [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -60,7 +60,7 @@ struct FourStepParams {
void fft_op( void fft_op(
const array& in, const array& in,
array& out, array& out,
size_t axis, int64_t axis,
bool inverse, bool inverse,
bool real, bool real,
const FourStepParams four_step_params, const FourStepParams four_step_params,
@@ -93,7 +93,7 @@ std::vector<int> plan_stockham_fft(int n) {
if (n == 1) { if (n == 1) {
return plan; return plan;
} }
for (int i = 0; i < radices.size(); i++) { for (int i = 0; i < std::ssize(radices); i++) {
int radix = radices[i]; int radix = radices[i];
// Manually tuned radices for powers of 2 // Manually tuned radices for powers of 2
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) { if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
@@ -181,7 +181,7 @@ int compute_elems_per_thread(FFTPlan plan) {
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end()); steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end()); steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
std::set<int> used_radices; std::set<int> used_radices;
for (int i = 0; i < steps.size(); i++) { for (int i = 0; i < std::ssize(steps); i++) {
int radix = radices[i % radices.size()]; int radix = radices[i % radices.size()];
if (steps[i] > 0) { if (steps[i] > 0) {
used_radices.insert(radix); used_radices.insert(radix);
@@ -260,7 +260,7 @@ int primitive_root(int n) {
std::tuple<array, array, array> compute_raders_constants( std::tuple<array, array, array> compute_raders_constants(
int rader_n, int rader_n,
const Stream& s) { const Stream& /* s */) {
int proot = primitive_root(rader_n); int proot = primitive_root(rader_n);
// Fermat's little theorem // Fermat's little theorem
int inv = mod_exp(proot, rader_n - 2, rader_n); int inv = mod_exp(proot, rader_n - 2, rader_n);
@@ -508,7 +508,7 @@ void four_step_fft(
void fft_op( void fft_op(
const array& in, const array& in,
array& out, array& out,
size_t axis, int64_t axis,
bool inverse, bool inverse,
bool real, bool real,
const FourStepParams four_step_params, const FourStepParams four_step_params,
@@ -612,11 +612,11 @@ void fft_op(
// Start of radix/rader step constants // Start of radix/rader step constants
int index = 4; int index = 4;
for (int i = 0; i < plan.stockham.size(); i++) { for (int i = 0; i < std::ssize(plan.stockham); i++) {
func_consts.push_back(make_int(&plan.stockham[i], index)); func_consts.push_back(make_int(&plan.stockham[i], index));
index += 1; index += 1;
} }
for (int i = 0; i < plan.rader.size(); i++) { for (int i = 0; i < std::ssize(plan.rader); i++) {
func_consts.push_back(make_int(&plan.rader[i], index)); func_consts.push_back(make_int(&plan.rader[i], index));
index += 1; index += 1;
} }
@@ -771,8 +771,8 @@ void nd_fft_op(
array temp1(temp_shape, complex64, nullptr, {}); array temp1(temp_shape, complex64, nullptr, {});
array temp2(temp_shape, complex64, nullptr, {}); array temp2(temp_shape, complex64, nullptr, {});
std::vector<array> temp_arrs = {temp1, temp2}; std::vector<array> temp_arrs = {temp1, temp2};
for (int i = axes.size() - 1; i >= 0; i--) { for (int i = std::ssize(axes) - 1; i >= 0; i--) {
int reverse_index = axes.size() - i - 1; int reverse_index = std::ssize(axes) - i - 1;
// For 5D and above, we don't want to reallocate our two temporary arrays // For 5D and above, we don't want to reallocate our two temporary arrays
bool inplace = reverse_index >= 3 && i != 0; bool inplace = reverse_index >= 3 && i != 0;
// Opposite order for fft vs ifft // Opposite order for fft vs ifft
@@ -780,8 +780,8 @@ void nd_fft_op(
size_t axis = axes[index]; size_t axis = axes[index];
// Mirror np.fft.(i)rfftn and perform a real transform // Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis. // only on the final axis.
bool step_real = (real && index == axes.size() - 1); bool step_real = (real && index == std::ssize(axes) - 1);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2]; const array& in_arr = i == std::ssize(axes) - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2]; array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
} }

View File

@@ -43,7 +43,7 @@ std::string gen_hadamard_codelet(int m) {
while (end != std::string_view::npos) { while (end != std::string_view::npos) {
source << " tmp[" << index << "] = "; source << " tmp[" << index << "] = ";
auto row = matrix.substr(start, end - start); auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) { for (int i = 0; i < std::ssize(row); i++) {
source << " " << row[i] << " x[" << i << "]"; source << " " << row[i] << " x[" << i << "]";
} }
source << ";" << std::endl; source << ";" << std::endl;

View File

@@ -52,7 +52,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
size_t slice_size = 1; int64_t slice_size = 1;
for (auto s : slice_sizes_) { for (auto s : slice_sizes_) {
slice_size *= s; slice_size *= s;
} }
@@ -94,8 +94,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread; int64_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
size_t dim_y = indices.size(); int64_t dim_y = indices.size();
auto group_dims = get_block_dims(dim_x, dim_y, 1); auto group_dims = get_block_dims(dim_x, dim_y, 1);
MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1); MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
int idx_ndim = nidx ? inputs[1].ndim() : 0; int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim(); int64_t ndim = src.ndim();
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
"gather{0}{1}_{2}_{3}_{4}", "gather{0}{1}_{2}_{3}_{4}",
@@ -149,8 +149,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch 3D grid of threads // Launch 3D grid of threads
// First two dimensions for the indices, the last one for the slice // First two dimensions for the indices, the last one for the slice
size_t dim0 = 1; int64_t dim0 = 1;
size_t dim1 = 1; int64_t dim1 = 1;
if (nidx) { if (nidx) {
if (inputs[1].ndim() >= 1) { if (inputs[1].ndim() >= 1) {
dim0 = inputs[1].shape(0); dim0 = inputs[1].shape(0);
@@ -159,13 +159,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dim1 = inputs[1].size() / dim0; dim1 = inputs[1].size() / dim0;
} }
} }
size_t dim2 = slice_size; int64_t dim2 = slice_size;
auto group_dims = get_block_dims(dim0, dim1, dim2); auto group_dims = get_block_dims(dim0, dim1, dim2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2); MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
// Collect all idx shapes and strides into one place // Collect all idx shapes and strides into one place
std::vector<int> idx_shapes; std::vector<int> idx_shapes;
std::vector<size_t> idx_strides; std::vector<int64_t> idx_strides;
std::vector<char> idx_contigs; std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
idx_shapes.insert( idx_shapes.insert(
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0; int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t idx_size = nidx ? inputs[1].size() : 1; int64_t idx_size = nidx ? inputs[1].size() : 1;
auto idx_to_out = idx_size / out.size(); auto idx_to_out = idx_size / out.size();
int nwork; int nwork;
@@ -345,7 +345,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
size_t nthreads = upd.size(); int64_t nthreads = upd.size();
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@@ -354,8 +354,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Set update info // Set update info
size_t upd_ndim = upd.ndim(); int64_t upd_ndim = upd.ndim();
size_t upd_size = 1; int64_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) { for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i); upd_size *= upd.shape(i);
} }
@@ -391,7 +391,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_bytes(upd_size, 6); compute_encoder.set_bytes(upd_size, 6);
// Set output info // Set output info
size_t out_ndim = out.ndim(); int64_t out_ndim = out.ndim();
if (out_ndim == 0) { if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain // Need placeholders so Metal doesn't complain
int shape_ = 0; int shape_ = 0;
@@ -448,7 +448,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
size_t ndim = src.ndim(); int64_t ndim = src.ndim();
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
@@ -486,8 +486,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Grid [size post, index size, size pre] // Grid [size post, index size, size pre]
size_t size_pre = 1; int64_t size_pre = 1;
size_t size_post = 1; int64_t size_post = 1;
for (int i = 0; i < axis_; ++i) { for (int i = 0; i < axis_; ++i) {
size_pre *= idx.shape(i); size_pre *= idx.shape(i);
} }
@@ -541,7 +541,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
size_t ndim = src.ndim(); int64_t ndim = src.ndim();
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
@@ -602,8 +602,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Grid [size post, index size, size pre] // Grid [size post, index size, size pre]
size_t size_pre = 1; int64_t size_pre = 1;
size_t size_post = 1; int64_t size_post = 1;
for (int i = 0; i < axis_; ++i) { for (int i = 0; i < axis_; ++i) {
size_pre *= idx.shape(i); size_pre *= idx.shape(i);
} }

View File

@@ -344,7 +344,7 @@ void steel_gemm_splitk_axpby(
int M, int M,
int N, int N,
int K, int K,
int batch_size_out, int /* batch_size_out */,
int lda, int lda,
int ldb, int ldb,
bool transpose_a, bool transpose_a,

View File

@@ -179,8 +179,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
const array&, const array&,
const std::optional<array>& mask_out, const std::optional<array>& /* mask_out */,
const std::optional<array>& mask_op, const std::optional<array>& /* mask_op */,
bool, bool,
bool, bool,
int, int,

View File

@@ -134,7 +134,7 @@ void RMSNormVJP::eval_gpu(
d.add_temporary(g, s.index); d.add_temporary(g, s.index);
} }
auto axis_size = static_cast<uint32_t>(x.shape().back()); auto axis_size = x.shape().back();
int n_rows = x.data_size() / axis_size; int n_rows = x.data_size() / axis_size;
// Allocate the gradient accumulator gw and a temporary to store the // Allocate the gradient accumulator gw and a temporary to store the
@@ -246,7 +246,7 @@ void LayerNorm::eval_gpu(
const array& w = inputs[1]; const array& w = inputs[1];
const array& b = inputs[2]; const array& b = inputs[2];
auto axis_size = static_cast<uint32_t>(x.shape().back()); auto axis_size = x.shape().back();
int n_rows = x.data_size() / axis_size; int n_rows = x.data_size() / axis_size;
int simd_size = 32; int simd_size = 32;
@@ -344,7 +344,7 @@ void LayerNormVJP::eval_gpu(
d.add_temporary(g, s.index); d.add_temporary(g, s.index);
} }
auto axis_size = static_cast<uint32_t>(x.shape().back()); auto axis_size = x.shape().back();
int n_rows = x.data_size() / axis_size; int n_rows = x.data_size() / axis_size;
// Allocate a temporary to store the gradients for w and allocate the output // Allocate a temporary to store the gradients for w and allocate the output

View File

@@ -152,7 +152,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
void Load::eval_gpu(const std::vector<array>& inputs, array& out) { void Load::eval_gpu(const std::vector<array>& /* inputs */, array& /* out */) {
throw std::runtime_error("[Load::eval_gpu] Not implemented."); throw std::runtime_error("[Load::eval_gpu] Not implemented.");
} }
@@ -201,41 +201,45 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
void QRF::eval_gpu( void QRF::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& /* inputs */,
std::vector<array>& outputs) { std::vector<array>& /* outputs */) {
throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI."); throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
} }
void SVD::eval_gpu( void SVD::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& /* inputs */,
std::vector<array>& outputs) { std::vector<array>& /* outputs */) {
throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI."); throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
} }
void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) { void Inverse::eval_gpu(
const std::vector<array>& /* inputs */,
array& /* output */) {
throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI."); throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
} }
void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) { void Cholesky::eval_gpu(
const std::vector<array>& /* inputs */,
array& /* out */) {
throw std::runtime_error( throw std::runtime_error(
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
} }
void Eig::eval_gpu( void Eig::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& /* inputs */,
std::vector<array>& outputs) { std::vector<array>& /* outputs */) {
throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI."); throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
} }
void Eigh::eval_gpu( void Eigh::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& /* inputs */,
std::vector<array>& outputs) { std::vector<array>& /* outputs */) {
throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
} }
void LUF::eval_gpu( void LUF::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& /* inputs */,
std::vector<array>& outputs) { std::vector<array>& /* outputs */) {
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
} }

View File

@@ -291,7 +291,7 @@ void init_reduce(
const std::string& op_name, const std::string& op_name,
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& /* s */) {
auto [_, out_type] = remap_reduce_types(out, op_name); auto [_, out_type] = remap_reduce_types(out, op_name);
const std::string func_name = "init_reduce"; const std::string func_name = "init_reduce";
std::string kname = func_name; std::string kname = func_name;
@@ -397,7 +397,7 @@ void row_reduce_small(
RowReduceArgs& args, RowReduceArgs& args,
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& /* s */) {
// Set the kernel // Set the kernel
int n = get_kernel_reduce_ndim(args.reduce_ndim); int n = get_kernel_reduce_ndim(args.reduce_ndim);
auto [in_type, out_type] = remap_reduce_types(in, op_name); auto [in_type, out_type] = remap_reduce_types(in, op_name);
@@ -453,7 +453,7 @@ void row_reduce_simple(
RowReduceArgs& args, RowReduceArgs& args,
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& /* s */) {
// Set the kernel // Set the kernel
auto [in_type, out_type] = remap_reduce_types(in, op_name); auto [in_type, out_type] = remap_reduce_types(in, op_name);
const std::string func_name = "row_reduce_simple"; const std::string func_name = "row_reduce_simple";
@@ -493,7 +493,7 @@ void row_reduce_looped(
RowReduceArgs& args, RowReduceArgs& args,
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& /* s */) {
auto [in_type, out_type] = remap_reduce_types(in, op_name); auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Set the kernel // Set the kernel
@@ -570,7 +570,7 @@ void strided_reduce_small(
ColReduceArgs& args, ColReduceArgs& args,
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& /* s */) {
auto [in_type, out_type] = remap_reduce_types(in, op_name); auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Figure out the grid dims // Figure out the grid dims
@@ -747,7 +747,7 @@ void strided_reduce_looped(
ColReduceArgs& args, ColReduceArgs& args,
CommandEncoder& compute_encoder, CommandEncoder& compute_encoder,
metal::Device& d, metal::Device& d,
const Stream& s) { const Stream& /* s */) {
auto [in_type, out_type] = remap_reduce_types(in, op_name); auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the arguments for the kernel // Prepare the arguments for the kernel
@@ -959,7 +959,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Continue with reduction operation // Continue with reduction operation
// Minimum of 4 bytes since we use size 4 structs for all reduce // Minimum of 4 bytes since we use size 4 structs for all reduce
// and metal will complain o/w // and metal will complain o/w
size_t min_bytes = std::max(out.nbytes(), 4ul); size_t min_bytes = std::max<int64_t>(out.nbytes(), 4);
out.set_data(allocator::malloc(min_bytes)); out.set_data(allocator::malloc(min_bytes));
std::string op_name; std::string op_name;
switch (reduce_type_) { switch (reduce_type_) {

View File

@@ -80,7 +80,7 @@ void ResidencySet::resize(size_t size) {
// Remove wired allocations until under capacity // Remove wired allocations until under capacity
auto allocations = wired_set_->allAllocations(); auto allocations = wired_set_->allAllocations();
auto num_allocations = wired_set_->allocationCount(); auto num_allocations = wired_set_->allocationCount();
for (int i = 0; i < num_allocations && current_size > size; ++i) { for (size_t i = 0; i < num_allocations && current_size > size; ++i) {
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i)); auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
wired_set_->removeAllocation(buf); wired_set_->removeAllocation(buf);
current_size -= buf->allocatedSize(); current_size -= buf->allocatedSize();

View File

@@ -33,7 +33,7 @@ void concatenate_gpu(
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto concurrent_ctx = compute_encoder.start_concurrent(); auto concurrent_ctx = compute_encoder.start_concurrent();
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < std::ssize(inputs); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i]; size_t data_offset = strides[axis] * sizes[i];
out_slice.copy_shared_buffer( out_slice.copy_shared_buffer(

View File

@@ -29,6 +29,10 @@ inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
std::ostringstream label; std::ostringstream label;
label << "Stream " << index; label << "Stream " << index;
queue->setLabel(make_string(label)); queue->setLabel(make_string(label));
#else
// appease warnings
(void)queue;
(void)index;
#endif #endif
} }
@@ -42,6 +46,9 @@ inline void debug_set_primitive_buffer_label(
} }
label << primitive.name(); label << primitive.name();
command_buffer->setLabel(make_string(label)); command_buffer->setLabel(make_string(label));
#else
(void)command_buffer;
(void)primitive;
#endif #endif
} }