Mirror of https://github.com/ml-explore/mlx.git

Commit: WIP (metal)

@@ -109,7 +109,7 @@ inline void build_kernel(
 
   // Read constant / contiguous inputs in tmps
   std::vector<array> nc_inputs;
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
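
Note: this change (and the similar ones throughout this commit) swaps the unsigned `.size()` bound for C++20 `std::ssize`, which returns a signed value, so comparing against an `int` index no longer trips `-Wsign-compare`. A minimal standalone sketch of the warning and the fix (the vector contents are placeholders):

    #include <cstdio>
    #include <iterator> // std::ssize (C++20)
    #include <vector>

    int main() {
      std::vector<int> inputs{1, 2, 3};
      // inputs.size() is size_t (unsigned); 'i < inputs.size()' with an int
      // index compares signed with unsigned and warns under -Wsign-compare.
      // std::ssize(inputs) returns a signed ptrdiff_t, so the comparison
      // below is signed-vs-signed and compiles clean.
      for (int i = 0; i < std::ssize(inputs); ++i) {
        std::printf("inputs[%d] = %d\n", i, inputs[i]);
      }
    }
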
@@ -134,7 +134,7 @@ inline void build_kernel(
   }
 
   // Initialize the indices for non-contiguous inputs
-  for (int i = 0; i < nc_inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(nc_inputs); ++i) {
     auto& xname = namer.get_name(nc_inputs[i]);
     os += fmt::format(" {0} index_{1} = ", idx_type, xname);
     if (ndim == 1) {
@@ -174,7 +174,7 @@ inline void build_kernel(
     os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
   }
   os += " uint l = zpos % output_shape[d];\n";
-  for (int i = 0; i < nc_inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(nc_inputs); ++i) {
     auto& xname = namer.get_name(nc_inputs[i]);
     os += fmt::format(" index_{0} += ", xname);
     if (dynamic_dims) {
@@ -195,7 +195,7 @@ inline void build_kernel(
   }
 
   // Read non-contiguous inputs into tmps
-  for (int i = 0; i < nc_inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(nc_inputs); ++i) {
     auto& x = nc_inputs[i];
     auto& xname = namer.get_name(x);
     os += fmt::format(
@@ -214,7 +214,7 @@ inline void build_kernel(
     } else {
       os += x.primitive().name();
      os += "()(";
-      for (int i = 0; i < x.inputs().size() - 1; i++) {
+      for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
         os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
       }
       os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
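
The `- 1` case above is where signedness matters most: with an empty container, `x.inputs().size() - 1` wraps around to `SIZE_MAX` and the loop would run, while `std::ssize(...) - 1` is simply `-1` and the loop body is skipped. A small sketch of the difference:

    #include <iostream>
    #include <iterator>
    #include <vector>

    int main() {
      std::vector<int> v; // deliberately empty
      std::cout << v.size() - 1 << '\n';      // wraps: 18446744073709551615 on 64-bit
      std::cout << std::ssize(v) - 1 << '\n'; // -1
      for (int i = 0; i < std::ssize(v) - 1; i++) {
        // never entered for an empty vector; the unsigned form would loop
      }
    }
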
@@ -227,7 +227,7 @@ inline void build_kernel(
   }
   // Increment indices and close per thread loop
   if (work_per_thread > 1) {
-    for (int i = 0; i < nc_inputs.size(); ++i) {
+    for (int i = 0; i < std::ssize(nc_inputs); ++i) {
       auto& x = nc_inputs[i];
       auto& xname = namer.get_name(x);
       if (!dynamic_dims) {
@@ -396,7 +396,7 @@ void Compiled::eval_gpu(
   int cnt = 0;
   int stride_idx = 1; // idx 0 is the output strides
   Strides in_strides;
-  for (int i = 0; i < inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(inputs); i++) {
     if (is_constant_(i)) {
       continue;
     }
@@ -990,7 +990,7 @@ void conv_3D_gpu(
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
     bool flip,
-    std::vector<array>& copies) {
+    std::vector<array>& /* copies */) {
   // Make conv params
   MLXConvParams<3> conv_params{
       /* const int N = */ static_cast<int>(in.shape(0)),
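
Commenting out the name of an unused parameter, as in `/* copies */` above, keeps the signature (and the documentation value of the name) while silencing `-Wunused-parameter`. A minimal sketch with hypothetical names:

    #include <cstdio>

    // 'detail' is accepted but not yet consumed; leaving the name as a
    // comment silences -Wunused-parameter without changing the signature.
    void log_event(int id, const char* /* detail */) {
      std::printf("event %d\n", id);
    }

    int main() {
      log_event(42, "ignored for now");
    }
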
@@ -68,7 +68,7 @@ std::string write_signature(
   int index = 0;
   constexpr int max_constant_array_size = 8;
   // Add inputs
-  for (int i = 0; i < inputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(inputs); ++i) {
     const auto& name = input_names[i];
     const auto& arr = inputs[i];
     auto dtype = get_type_string(arr.dtype());
@@ -109,7 +109,7 @@ std::string write_signature(
     }
   }
   // Add outputs
-  for (int i = 0; i < output_names.size(); ++i) {
+  for (int i = 0; i < std::ssize(output_names); ++i) {
     const auto& name = output_names[i];
     const auto& dtype = output_dtypes[i];
     kernel_source += " device ";
@@ -126,8 +126,8 @@ std::string write_signature(
     kernel_source += " [[buffer(";
     kernel_source += std::to_string(index);
     kernel_source += ")]]";
-    if (index < inputs.size() + output_names.size() - 1 ||
-        attributes.size() > 0) {
+    if (index < std::ssize(inputs) + std::ssize(output_names) - 1 ||
+        std::ssize(attributes) > 0) {
       kernel_source += ",\n";
     } else {
       kernel_source += ") {\n";
@@ -138,7 +138,7 @@ std::string write_signature(
   index = 0;
   for (const auto& attr : attributes) {
     kernel_source += attr;
-    if (index < attributes.size() - 1) {
+    if (index < std::ssize(attributes) - 1) {
       kernel_source += ",\n";
     } else {
       kernel_source += ") {\n";
@@ -381,7 +381,7 @@ void CustomKernel::eval_gpu(
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
   int index = 0;
-  for (int i = 0; i < checked_inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(checked_inputs); i++) {
     const array& in = checked_inputs[i];
     auto& shape_info = shape_infos_[i];
     compute_encoder.set_input_array(in, index);
@@ -408,7 +408,7 @@ void CustomKernel::eval_gpu(
   }
 
   const auto [tx, ty, tz] = threadgroup_;
-  auto tg_size = tx * ty * tz;
+  unsigned long tg_size = tx * ty * tz;
   auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
   if (tg_size > max_tg_size) {
     std::ostringstream msg;
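
`maxTotalThreadsPerThreadgroup()` returns `NS::UInteger` (an `unsigned long`), so leaving `tg_size` as a deduced `int` makes the `tg_size > max_tg_size` check a signed/unsigned comparison. Declaring it `unsigned long` keeps both sides the same type. A sketch with stand-in values (the Metal call is replaced by a constant):

    #include <cstdio>

    int main() {
      // Stand-in for kernel->maxTotalThreadsPerThreadgroup().
      unsigned long max_tg_size = 1024;
      int tx = 32, ty = 32, tz = 2;
      // 'auto' would deduce int for the product; typing it unsigned long
      // keeps the comparison below homogeneous.
      unsigned long tg_size = static_cast<unsigned long>(tx) * ty * tz;
      if (tg_size > max_tg_size) {
        std::printf("threadgroup too big: %lu > %lu\n", tg_size, max_tg_size);
      }
    }
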
@@ -127,6 +127,9 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
       }
     }
   }
+#else
+  (void)device;
+  (void)lib_name;
 #endif
   return {nullptr, nullptr};
 }
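
The added `#else` branch uses the classic `(void)param;` idiom: when the real body is compiled out, each parameter is still "used" once, so every build configuration is warning-free. A self-contained sketch of the pattern (function name and macro are illustrative):

    #include <cstdio>

    void load_library_sketch(void* device, const char* lib_name) {
    #ifdef SWIFTPM_BUNDLE
      std::printf("searching for %s on device %p\n", lib_name, device);
    #else
      (void)device;   // parameters are otherwise unused in this configuration
      (void)lib_name;
    #endif
    }

    int main() {
      load_library_sketch(nullptr, "mlx.metallib");
    }
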
@@ -713,7 +716,7 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
   auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
 
   std::vector<NS::Object*> objs(funcs.size());
-  for (int i = 0; i < funcs.size(); i++) {
+  for (int i = 0; i < std::ssize(funcs); i++) {
     objs[i] = funcs[i];
   }
 
@@ -137,7 +137,7 @@ struct DeviceStream {
   // Data updated between command buffers
   MTL::CommandBuffer* buffer{nullptr};
   int buffer_ops{0};
-  size_t buffer_sizes{0};
+  int64_t buffer_sizes{0};
 
   // The command encoder, fence, and temporaries are updated between command
   // encoders
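
Switching the byte counter from `size_t` to `int64_t` is part of the same signedness sweep: a signed accumulator makes an over-subtraction visible as a negative value instead of a silent wraparound. A toy illustration:

    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t buffer_sizes = 0;
      buffer_sizes += 1024;
      buffer_sizes -= 4096; // a bookkeeping bug
      // With size_t this would wrap to a huge positive number; signed math
      // leaves an obviously wrong (negative) value that is easy to check.
      if (buffer_sizes < 0) {
        std::printf("underflow detected: %lld\n", (long long)buffer_sizes);
      }
    }
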
@@ -76,7 +76,7 @@ void Fence::wait(Stream stream, const array& x) {
     auto command_buffer = d.get_command_buffer(idx);
     command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count);
     command_buffer->addCompletedHandler(
-        [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+        [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
     return;
   }
 
@@ -96,7 +96,7 @@ void Fence::wait(Stream stream, const array& x) {
   compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
 
   d.get_command_buffer(idx)->addCompletedHandler(
-      [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+      [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
 }
 
 void Fence::update(Stream stream, const array& x) {
@@ -124,7 +124,7 @@ void Fence::update(Stream stream, const array& x) {
     command_buffer->encodeSignalEvent(
         static_cast<MTL::Event*>(f.fence), f.count);
     command_buffer->addCompletedHandler(
-        [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+        [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
     return;
   }
 
@@ -154,7 +154,7 @@ void Fence::update(Stream stream, const array& x) {
   compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
 
   d.get_command_buffer(idx)->addCompletedHandler(
-      [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+      [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {});
 }
 
 } // namespace mlx::core
@@ -60,7 +60,7 @@ struct FourStepParams {
 void fft_op(
     const array& in,
     array& out,
-    size_t axis,
+    int64_t axis,
     bool inverse,
     bool real,
     const FourStepParams four_step_params,
@@ -93,7 +93,7 @@ std::vector<int> plan_stockham_fft(int n) {
   if (n == 1) {
     return plan;
   }
-  for (int i = 0; i < radices.size(); i++) {
+  for (int i = 0; i < std::ssize(radices); i++) {
     int radix = radices[i];
     // Manually tuned radices for powers of 2
     if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
@@ -181,7 +181,7 @@ int compute_elems_per_thread(FFTPlan plan) {
   steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
   steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
   std::set<int> used_radices;
-  for (int i = 0; i < steps.size(); i++) {
+  for (int i = 0; i < std::ssize(steps); i++) {
     int radix = radices[i % radices.size()];
     if (steps[i] > 0) {
       used_radices.insert(radix);
@@ -260,7 +260,7 @@ int primitive_root(int n) {
 
 std::tuple<array, array, array> compute_raders_constants(
     int rader_n,
-    const Stream& s) {
+    const Stream& /* s */) {
   int proot = primitive_root(rader_n);
   // Fermat's little theorem
   int inv = mod_exp(proot, rader_n - 2, rader_n);
@@ -508,7 +508,7 @@ void four_step_fft(
 void fft_op(
     const array& in,
     array& out,
-    size_t axis,
+    int64_t axis,
     bool inverse,
     bool real,
     const FourStepParams four_step_params,
@@ -612,11 +612,11 @@ void fft_op(
 
   // Start of radix/rader step constants
   int index = 4;
-  for (int i = 0; i < plan.stockham.size(); i++) {
+  for (int i = 0; i < std::ssize(plan.stockham); i++) {
     func_consts.push_back(make_int(&plan.stockham[i], index));
     index += 1;
   }
-  for (int i = 0; i < plan.rader.size(); i++) {
+  for (int i = 0; i < std::ssize(plan.rader); i++) {
     func_consts.push_back(make_int(&plan.rader[i], index));
     index += 1;
   }
@@ -771,8 +771,8 @@ void nd_fft_op(
   array temp1(temp_shape, complex64, nullptr, {});
   array temp2(temp_shape, complex64, nullptr, {});
   std::vector<array> temp_arrs = {temp1, temp2};
-  for (int i = axes.size() - 1; i >= 0; i--) {
-    int reverse_index = axes.size() - i - 1;
+  for (int i = std::ssize(axes) - 1; i >= 0; i--) {
+    int reverse_index = std::ssize(axes) - i - 1;
     // For 5D and above, we don't want to reallocate our two temporary arrays
     bool inplace = reverse_index >= 3 && i != 0;
     // Opposite order for fft vs ifft
@@ -780,8 +780,8 @@ void nd_fft_op(
     size_t axis = axes[index];
     // Mirror np.fft.(i)rfftn and perform a real transform
     // only on the final axis.
-    bool step_real = (real && index == axes.size() - 1);
-    const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
+    bool step_real = (real && index == std::ssize(axes) - 1);
+    const array& in_arr = i == std::ssize(axes) - 1 ? in : temp_arrs[1 - i % 2];
     array& out_arr = i == 0 ? out : temp_arrs[i % 2];
     fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
   }
@@ -43,7 +43,7 @@ std::string gen_hadamard_codelet(int m) {
   while (end != std::string_view::npos) {
     source << " tmp[" << index << "] = ";
     auto row = matrix.substr(start, end - start);
-    for (int i = 0; i < row.length(); i++) {
+    for (int i = 0; i < std::ssize(row); i++) {
       source << " " << row[i] << " x[" << i << "]";
     }
     source << ";" << std::endl;
@@ -52,7 +52,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  size_t slice_size = 1;
+  int64_t slice_size = 1;
   for (auto s : slice_sizes_) {
     slice_size *= s;
   }
@@ -94,8 +94,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto kernel = d.get_kernel(kernel_name, lib);
   compute_encoder.set_compute_pipeline_state(kernel);
 
-  size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
-  size_t dim_y = indices.size();
+  int64_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
+  int64_t dim_y = indices.size();
   auto group_dims = get_block_dims(dim_x, dim_y, 1);
   MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
 
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  size_t ndim = src.ndim();
+  int64_t ndim = src.ndim();
 
   std::string kernel_name = fmt::format(
       "gather{0}{1}_{2}_{3}_{4}",
@@ -149,8 +149,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Launch 3D grid of threads
   // First two dimensions for the indices, the last one for the slice
-  size_t dim0 = 1;
-  size_t dim1 = 1;
+  int64_t dim0 = 1;
+  int64_t dim1 = 1;
   if (nidx) {
     if (inputs[1].ndim() >= 1) {
       dim0 = inputs[1].shape(0);
@@ -159,13 +159,13 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       dim1 = inputs[1].size() / dim0;
     }
   }
-  size_t dim2 = slice_size;
+  int64_t dim2 = slice_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
 
   // Collect all idx shapes and strides into one place
   std::vector<int> idx_shapes;
-  std::vector<size_t> idx_strides;
+  std::vector<int64_t> idx_strides;
   std::vector<char> idx_contigs;
   for (int i = 0; i < nidx; ++i) {
     idx_shapes.insert(
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
 
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  size_t idx_size = nidx ? inputs[1].size() : 1;
+  int64_t idx_size = nidx ? inputs[1].size() : 1;
 
   auto idx_to_out = idx_size / out.size();
   int nwork;
@@ -345,7 +345,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);
 
-  size_t nthreads = upd.size();
+  int64_t nthreads = upd.size();
 
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -354,8 +354,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 2);
 
   // Set update info
-  size_t upd_ndim = upd.ndim();
-  size_t upd_size = 1;
+  int64_t upd_ndim = upd.ndim();
+  int64_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
@@ -391,7 +391,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_bytes(upd_size, 6);
 
   // Set output info
-  size_t out_ndim = out.ndim();
+  int64_t out_ndim = out.ndim();
   if (out_ndim == 0) {
     // Need placeholders so Metal doesn't complain
     int shape_ = 0;
@@ -448,7 +448,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  size_t ndim = src.ndim();
+  int64_t ndim = src.ndim();
 
   bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
 
@@ -486,8 +486,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Grid [size post, index size, size pre]
-  size_t size_pre = 1;
-  size_t size_post = 1;
+  int64_t size_pre = 1;
+  int64_t size_post = 1;
   for (int i = 0; i < axis_; ++i) {
     size_pre *= idx.shape(i);
   }
@@ -541,7 +541,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  size_t ndim = src.ndim();
+  int64_t ndim = src.ndim();
 
   bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
 
@@ -602,8 +602,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Grid [size post, index size, size pre]
-  size_t size_pre = 1;
-  size_t size_post = 1;
+  int64_t size_pre = 1;
+  int64_t size_post = 1;
   for (int i = 0; i < axis_; ++i) {
     size_pre *= idx.shape(i);
   }
@@ -344,7 +344,7 @@ void steel_gemm_splitk_axpby(
     int M,
     int N,
     int K,
-    int batch_size_out,
+    int /* batch_size_out */,
     int lda,
     int ldb,
     bool transpose_a,
@@ -179,8 +179,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const array&,
-    const std::optional<array>& mask_out,
-    const std::optional<array>& mask_op,
+    const std::optional<array>& /* mask_out */,
+    const std::optional<array>& /* mask_op */,
     bool,
     bool,
     int,
@@ -134,7 +134,7 @@ void RMSNormVJP::eval_gpu(
     d.add_temporary(g, s.index);
   }
 
-  auto axis_size = static_cast<uint32_t>(x.shape().back());
+  auto axis_size = x.shape().back();
   int n_rows = x.data_size() / axis_size;
 
   // Allocate the gradient accumulator gw and a temporary to store the
@@ -246,7 +246,7 @@ void LayerNorm::eval_gpu(
   const array& w = inputs[1];
   const array& b = inputs[2];
 
-  auto axis_size = static_cast<uint32_t>(x.shape().back());
+  auto axis_size = x.shape().back();
   int n_rows = x.data_size() / axis_size;
 
   int simd_size = 32;
@@ -344,7 +344,7 @@ void LayerNormVJP::eval_gpu(
     d.add_temporary(g, s.index);
   }
 
-  auto axis_size = static_cast<uint32_t>(x.shape().back());
+  auto axis_size = x.shape().back();
   int n_rows = x.data_size() / axis_size;
 
   // Allocate a temporary to store the gradients for w and allocate the output
@@ -152,7 +152,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
+void Load::eval_gpu(const std::vector<array>& /* inputs */, array& /* out */) {
   throw std::runtime_error("[Load::eval_gpu] Not implemented.");
 }
 
@@ -201,41 +201,45 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void QRF::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
 }
 
 void SVD::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
 }
 
-void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
+void Inverse::eval_gpu(
+    const std::vector<array>& /* inputs */,
+    array& /* output */) {
   throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
 }
 
-void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
+void Cholesky::eval_gpu(
+    const std::vector<array>& /* inputs */,
+    array& /* out */) {
   throw std::runtime_error(
       "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
 }
 
 void Eig::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
 }
 
 void Eigh::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
 }
 
 void LUF::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<array>& /* inputs */,
+    std::vector<array>& /* outputs */) {
   throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
 }
 
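
For these not-yet-implemented overrides every parameter is unused, so each name moves into a comment and the throwing stub compiles clean under `-Wunused-parameter`. A simplified sketch (with `int` standing in for `mlx::core::array`):

    #include <stdexcept>
    #include <vector>

    struct SVDSketch {
      void eval_gpu(
          const std::vector<int>& /* inputs */,
          std::vector<int>& /* outputs */) {
        throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
      }
    };

    int main() {
      SVDSketch svd;
      std::vector<int> ins, outs;
      try {
        svd.eval_gpu(ins, outs);
      } catch (const std::runtime_error&) {
        // expected until a Metal SVD kernel lands
      }
    }
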
@@ -291,7 +291,7 @@ void init_reduce(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   auto [_, out_type] = remap_reduce_types(out, op_name);
   const std::string func_name = "init_reduce";
   std::string kname = func_name;
@@ -397,7 +397,7 @@ void row_reduce_small(
     RowReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   // Set the kernel
   int n = get_kernel_reduce_ndim(args.reduce_ndim);
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
@@ -453,7 +453,7 @@ void row_reduce_simple(
     RowReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   // Set the kernel
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
   const std::string func_name = "row_reduce_simple";
@@ -493,7 +493,7 @@ void row_reduce_looped(
     RowReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Set the kernel
@@ -570,7 +570,7 @@ void strided_reduce_small(
     ColReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Figure out the grid dims
@@ -747,7 +747,7 @@ void strided_reduce_looped(
     ColReduceArgs& args,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s) {
+    const Stream& /* s */) {
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Prepare the arguments for the kernel
@@ -959,7 +959,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Continue with reduction operation
   // Minimum of 4 bytes since we use size 4 structs for all reduce
   // and metal will complain o/w
-  size_t min_bytes = std::max(out.nbytes(), 4ul);
+  size_t min_bytes = std::max<int64_t>(out.nbytes(), 4);
   out.set_data(allocator::malloc(min_bytes));
   std::string op_name;
   switch (reduce_type_) {
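
`std::max(out.nbytes(), 4ul)` relies on `size_t` being `unsigned long`, which holds on LP64 Apple/Linux targets but not on LLP64 ones, where the two-argument deduction fails. Spelling `std::max<int64_t>` converts both operands to one portable signed type. A sketch:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      size_t nbytes = 2; // e.g. the byte size of a single float16 output
      // std::max<int64_t>(...) forces both arguments to int64_t, so the call
      // deduces the same way on every platform and never mixes signedness.
      int64_t min_bytes = std::max<int64_t>(nbytes, 4);
      std::printf("%lld\n", (long long)min_bytes); // 4
    }
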
@@ -80,7 +80,7 @@ void ResidencySet::resize(size_t size) {
     // Remove wired allocations until under capacity
     auto allocations = wired_set_->allAllocations();
     auto num_allocations = wired_set_->allocationCount();
-    for (int i = 0; i < num_allocations && current_size > size; ++i) {
+    for (size_t i = 0; i < num_allocations && current_size > size; ++i) {
       auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
       wired_set_->removeAllocation(buf);
       current_size -= buf->allocatedSize();
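
Here the fix runs in the other direction: `allocationCount()` returns an unsigned value, so the induction variable becomes `size_t` to match the bound rather than making the bound signed. A sketch with stand-in numbers:

    #include <cstddef>
    #include <cstdio>

    int main() {
      size_t num_allocations = 3; // stand-in for wired_set_->allocationCount()
      size_t current_size = 4096, target = 1024;
      // A size_t index matches the unsigned bound, so the loop condition
      // never compares signed with unsigned.
      for (size_t i = 0; i < num_allocations && current_size > target; ++i) {
        current_size -= 2048; // pretend each removal frees 2 KiB
        std::printf("removed %zu, %zu bytes wired\n", i, current_size);
      }
    }
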
@@ -33,7 +33,7 @@ void concatenate_gpu(
   auto& d = metal::device(s.device);
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto concurrent_ctx = compute_encoder.start_concurrent();
-  for (int i = 0; i < inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(inputs); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
     size_t data_offset = strides[axis] * sizes[i];
     out_slice.copy_shared_buffer(
@@ -29,6 +29,10 @@ inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
   std::ostringstream label;
   label << "Stream " << index;
   queue->setLabel(make_string(label));
+#else
+  // appease warnings
+  (void)queue;
+  (void)index;
 #endif
 }
 
@@ -42,6 +46,9 @@ inline void debug_set_primitive_buffer_label(
   }
   label << primitive.name();
   command_buffer->setLabel(make_string(label));
+#else
+  (void)command_buffer;
+  (void)primitive;
 #endif
 }
 