mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
This commit is contained in:
parent
5fd11c347d
commit
28eac18571
15
mlx/array.h
15
mlx/array.h
@ -121,6 +121,9 @@ class array {
|
||||
template <typename T>
|
||||
T item();
|
||||
|
||||
template <typename T>
|
||||
T item() const;
|
||||
|
||||
struct ArrayIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
using difference_type = size_t;
|
||||
@ -454,6 +457,18 @@ T array::item() {
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T array::item() const {
|
||||
if (size() != 1) {
|
||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||
}
|
||||
if (!is_evaled()) {
|
||||
throw std::invalid_argument(
|
||||
"item() const can only be called on evaled arrays");
|
||||
}
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
template <typename It>
|
||||
void array::init(It src) {
|
||||
set_data(allocator::malloc(size() * size_of(dtype())));
|
||||
|
@ -1,3 +1,23 @@
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND /bin/bash
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
${CMAKE_C_COMPILER}
|
||||
${CMAKE_SOURCE_DIR}
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
kernels/compiled_preamble.h
|
||||
kernels/unary.h
|
||||
kernels/binary.h
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
compiled_preamble
|
||||
DEPENDS compiled_preamble.cpp
|
||||
)
|
||||
|
||||
add_dependencies(mlx compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
@ -16,6 +36,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_PATH)
|
||||
|
@ -1,44 +1,484 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/compiled_preamble.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline bool is_static_cast(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
|
||||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
inline auto get_type_string(Dtype d) {
|
||||
switch (d) {
|
||||
case float32:
|
||||
return "float";
|
||||
case float16:
|
||||
return "half";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case bool_:
|
||||
return "bool";
|
||||
case int8:
|
||||
return "int8_t";
|
||||
case int16:
|
||||
return "int16_t";
|
||||
case int32:
|
||||
return "int32_t";
|
||||
case int64:
|
||||
return "int64_t";
|
||||
case uint8:
|
||||
return "uint8_t";
|
||||
case uint16:
|
||||
return "uint16_t";
|
||||
case uint32:
|
||||
return "uint32_t";
|
||||
case uint64:
|
||||
return "uint64_t";
|
||||
default: {
|
||||
std::ostringstream msg;
|
||||
msg << "Unsupported compilation type " << d;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print_float_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< x.item<T>() << std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print_int_constant(std::ostream& os, const array& x) {
|
||||
os << x.item<T>();
|
||||
}
|
||||
|
||||
void print_constant(std::ostream& os, const array& x) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
return print_float_constant<float>(os, x);
|
||||
case float16:
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case int8:
|
||||
return print_int_constant<int8_t>(os, x);
|
||||
case int16:
|
||||
return print_int_constant<int16_t>(os, x);
|
||||
case int32:
|
||||
return print_int_constant<int32_t>(os, x);
|
||||
case int64:
|
||||
return print_int_constant<int64_t>(os, x);
|
||||
case uint8:
|
||||
return print_int_constant<uint8_t>(os, x);
|
||||
case uint16:
|
||||
return print_int_constant<uint16_t>(os, x);
|
||||
case uint32:
|
||||
return print_int_constant<uint32_t>(os, x);
|
||||
case uint64:
|
||||
return print_int_constant<uint64_t>(os, x);
|
||||
case bool_:
|
||||
os << std::boolalpha << x.item<bool>();
|
||||
return;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported constant type");
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
a.primitive().print(os);
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << ((x.size() == 1) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
inline void build_kernel(
|
||||
std::ostream& os,
|
||||
const std::string& kernel_name,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
bool contiguous,
|
||||
int ndim,
|
||||
bool dynamic_dims) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
// For scalar we shouldn't do the indexing things, just read at 0
|
||||
auto is_scalar = [](const array& x) { return x.size() == 1; };
|
||||
|
||||
NodeNamer namer;
|
||||
bool add_indices = false;
|
||||
int cnt = 0;
|
||||
|
||||
// Start the kernel
|
||||
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
|
||||
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
|
||||
|
||||
// Add the input arguments
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Scalars and contiguous need no strides
|
||||
if (is_scalar(x) || contiguous) {
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
} else {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl
|
||||
<< " constant const size_t* " << xname << "_strides [[buffer("
|
||||
<< cnt++ << ")]]," << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " constant const size_t* output_strides [[buffer(" << cnt++
|
||||
<< ")]]," << std::endl
|
||||
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
|
||||
<< std::endl;
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// The thread index in the whole grid
|
||||
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
|
||||
<< " uint3 grid [[threads_per_grid]]) {" << std::endl
|
||||
<< " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
|
||||
<< std::endl;
|
||||
|
||||
// Extract the indices per axis to individual uints if we have arrays that
|
||||
// are broadcasted or transposed
|
||||
if (add_indices) {
|
||||
if (!dynamic_dims) {
|
||||
if (ndim == 1) {
|
||||
os << " uint index_0 = pos.x;" << std::endl;
|
||||
} else if (ndim == 2) {
|
||||
os << " uint index_0 = pos.y;" << std::endl
|
||||
<< " uint index_1 = pos.x;" << std::endl;
|
||||
} else if (ndim == 3) {
|
||||
os << " uint index_0 = pos.z;" << std::endl
|
||||
<< " uint index_1 = pos.y;" << std::endl
|
||||
<< " uint index_2 = pos.x;" << std::endl;
|
||||
} else {
|
||||
for (int i = 0; i < ndim - 2; i++) {
|
||||
os << " uint index_" << i << " = (index / uint(output_strides[" << i
|
||||
<< "])) % output_shape[" << i << "];" << std::endl;
|
||||
}
|
||||
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
|
||||
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];" << std::endl;
|
||||
} else if (contiguous) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];" << std::endl;
|
||||
} else if (!dynamic_dims) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[";
|
||||
os << "index_0 * " << xname << "_strides[0]";
|
||||
for (int i = 1; i < ndim; i++) {
|
||||
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
|
||||
}
|
||||
os << "];" << std::endl;
|
||||
} else {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[elem_to_loc(index, output_shape, " << xname
|
||||
<< "_strides, ndim)];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Actually write the computation
|
||||
for (auto& x : tape) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
||||
<< " = ";
|
||||
if (is_static_cast(x.primitive())) {
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||
} else {
|
||||
x.primitive().print(os);
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
}
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the outputs from tmps
|
||||
for (auto& x : outputs) {
|
||||
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
|
||||
<< ";" << std::endl;
|
||||
}
|
||||
|
||||
// Finish the kernel
|
||||
os << "}" << std::endl;
|
||||
|
||||
if (cnt > 31) {
|
||||
std::ostringstream msg;
|
||||
msg << "[compile] Too many inputs/outputs fused in the Metal Compile "
|
||||
<< "primitive which exhausted the available argument buffers for "
|
||||
<< "the kernel. Please file an issue with the function that results "
|
||||
<< "in this error. The name of the kernel is '" << kernel_name << "'";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Just a fall-back to the original tape for now
|
||||
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
trace_to_real.insert({inputs_[i].id(), inputs[i]});
|
||||
}
|
||||
for (int i = 0; i < outputs.size(); ++i) {
|
||||
trace_to_real.insert({outputs_[i].id(), outputs[i]});
|
||||
// Make the name for the kernel library
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
for (auto& a : tape_) {
|
||||
std::vector<array> p_inputs;
|
||||
for (auto& in : a.inputs()) {
|
||||
p_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
// If a is an output get it from the map, otherwise create it
|
||||
// NB this is safe as long as no multi-output sub primitves are allowed
|
||||
// in Compiled
|
||||
std::vector<array> p_outputs;
|
||||
if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) {
|
||||
p_outputs.push_back(it->second);
|
||||
} else {
|
||||
p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {}));
|
||||
trace_to_real.insert({a.id(), p_outputs[0]});
|
||||
}
|
||||
a.primitive().eval_gpu(p_inputs, p_outputs);
|
||||
}
|
||||
// Get the kernel if someone else built it already
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler(
|
||||
[trace_to_real](MTL::CommandBuffer*) mutable {});
|
||||
auto lib = d.get_library(kernel_lib_);
|
||||
|
||||
// If not we have to build it ourselves
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel;
|
||||
kernel << metal::get_kernel_preamble() << std::endl;
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false);
|
||||
for (int i = 1; i < 8; i++) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_" + std::to_string(i),
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false);
|
||||
}
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_dynamic",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true);
|
||||
|
||||
kernel_source_ = kernel.str();
|
||||
lib = d.get_library(kernel_lib_, kernel_source_);
|
||||
}
|
||||
|
||||
// Allocate space for the outputs
|
||||
for (auto& out : outputs) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
for (auto& x : inputs) {
|
||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||
x.size() > 1) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
std::vector<std::vector<size_t>> initial_strides;
|
||||
initial_strides.push_back(outputs[0].strides());
|
||||
std::vector<int> shape;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
if (!contiguous) {
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
|
||||
// Skip scalar inputs.
|
||||
if (x.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < output_shape.size() - x.ndim(); j++) {
|
||||
if (output_shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (output_shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
initial_strides.push_back(std::move(xstrides));
|
||||
}
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(output_shape, initial_strides);
|
||||
}
|
||||
|
||||
// Get the kernel from the lib
|
||||
int ndim = shape.size();
|
||||
bool dynamic = ndim >= 8;
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
if (!contiguous) {
|
||||
if (dynamic) {
|
||||
kernel_name += "dynamic";
|
||||
} else {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
}
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
if (!contiguous && x.size() > 1) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
strides[stride_idx].size() * sizeof(size_t),
|
||||
cnt++);
|
||||
stride_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
}
|
||||
|
||||
// Put the output shape and strides in
|
||||
if (!contiguous) {
|
||||
compute_encoder->setBytes(
|
||||
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
if (dynamic) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
|
||||
}
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
size_t nthreads = outputs[0].size();
|
||||
MTL::Size grid_dims(nthreads, 1, 1);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
9
mlx/backend/metal/compiled_preamble.h
Normal file
9
mlx/backend/metal/compiled_preamble.h
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
const char* get_kernel_preamble();
|
||||
|
||||
}
|
@ -414,6 +414,11 @@ MTL::ComputePipelineState* Device::get_kernel_(
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(const std::string& name) {
|
||||
auto it = library_map_.find(name);
|
||||
return (it != library_map_.end()) ? it->second : nullptr;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
|
@ -62,6 +62,8 @@ class Device {
|
||||
const std::function<std::string(const std::string&)>& lib_path_func =
|
||||
get_colocated_mtllib_path);
|
||||
|
||||
MTL::Library* get_library(const std::string& name);
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::string& source_string,
|
||||
|
221
mlx/backend/metal/kernels/binary.h
Normal file
221
mlx/backend/metal/kernels/binary.h
Normal file
@ -0,0 +1,221 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
struct Add {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x % y;
|
||||
}
|
||||
template <>
|
||||
float operator()(float x, float y) {
|
||||
return fmod(x, y);
|
||||
}
|
||||
template <>
|
||||
half operator()(half x, half y) {
|
||||
return fmod(x, y);
|
||||
}
|
||||
template <>
|
||||
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
|
||||
return fmod(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x == y ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
|
||||
metal::isnan(y.imag)) ||
|
||||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x >= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x <= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
if (metal::isnan(x) || metal::isnan(y)) {
|
||||
return metal::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
return (minval == -inf || maxval == inf)
|
||||
? maxval
|
||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::min(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x != y;
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x.real != y.real || x.imag != y.imag;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return metal::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
};
|
@ -1,176 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
struct Add {
|
||||
template <typename T> T operator()(T x, T y) { return x + y; }
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T> T operator()(T x, T y) { return x % y; }
|
||||
template <> float operator()(float x, float y) { return fmod(x, y); }
|
||||
template <> half operator()(half x, half y) { return fmod(x, y); }
|
||||
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T> bool operator()(T x, T y) { return x == y; }
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T> bool operator()(T x, T y) {
|
||||
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x == y ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real)
|
||||
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T> bool operator()(T x, T y) { return x > y; }
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x >= y; }
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T> bool operator()(T x, T y) { return x < y; }
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x <= y; }
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
if (metal::isnan(x) || metal::isnan(y)) {
|
||||
return metal::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
return (minval == -inf || maxval == inf) ? maxval :
|
||||
(maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::min(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T> T operator()(T x, T y) { return x * y; }
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x != y; }
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x.real != y.real || x.imag != y.imag;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return metal::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct Subtract {
|
||||
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) { return x && y; };
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) { return x || y; };
|
||||
};
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_s2s(
|
||||
|
4
mlx/backend/metal/kernels/compiled_preamble.h
Normal file
4
mlx/backend/metal/kernels/compiled_preamble.h
Normal file
@ -0,0 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
376
mlx/backend/metal/kernels/unary.h
Normal file
376
mlx/backend/metal/kernels/unary.h
Normal file
@ -0,0 +1,376 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::abs(x);
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::acos(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::acosh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::asin(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::asinh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::atan(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::atanh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::ceil(x);
|
||||
};
|
||||
template <>
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::cos(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
||||
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::cosh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
||||
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(erf(static_cast<float>(x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(erfinv(static_cast<float>(x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::exp(x);
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto m = metal::precise::exp(x.real);
|
||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::floor(x);
|
||||
};
|
||||
template <>
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::log(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::log2(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::log10(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return log1p(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return !x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return -x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::rint(x);
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {metal::rint(x.real), metal::rint(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
};
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::sin(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
||||
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::sinh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
||||
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x * x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::sqrt(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::rsqrt(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::tan(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tan_a = metal::precise::tan(x.real);
|
||||
float tanh_b = metal::precise::tanh(x.imag);
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
||||
};
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::tanh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tanh_a = metal::precise::tanh(x.real);
|
||||
float tan_b = metal::precise::tan(x.imag);
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
};
|
||||
};
|
@ -1,223 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
struct Abs {
|
||||
template <typename T> T operator()(T x) { return metal::abs(x); };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
template <> complex64_t operator()(complex64_t x) {
|
||||
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T> T operator()(T x) { return metal::precise::acos(x); };
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::acosh(x); };
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T> T operator()(T x) { return metal::precise::asin(x); };
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::asinh(x); };
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T> T operator()(T x) { return metal::precise::atan(x); };
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T> T operator()(T x) { return metal::ceil(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
||||
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::cosh(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
||||
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T> T operator()(T x) { return static_cast<T>(erf(static_cast<float>(x))); };
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T> T operator()(T x) { return static_cast<T>(erfinv(static_cast<float>(x))); };
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T> T operator()(T x) { return metal::precise::exp(x); };
|
||||
template <> complex64_t operator()(complex64_t x) {
|
||||
auto m = metal::precise::exp(x.real);
|
||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T> T operator()(T x) { return metal::floor(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log(x); };
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log2(x); };
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log10(x); };
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T> T operator()(T x) { return log1p(x); };
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T> T operator()(T x) { return !x; };
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T> T operator()(T x) { return -x; };
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T> T operator()(T x) { return metal::rint(x); };
|
||||
template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; };
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T> T operator()(T x) { return (x > T(0)) - (x < T(0)); };
|
||||
template <> uint32_t operator()(uint32_t x) { return x != 0; };
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T> T operator()(T x) { return metal::precise::sin(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
||||
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::sinh(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
||||
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T> T operator()(T x) { return x * x; };
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T> T operator()(T x) { return metal::precise::sqrt(x); };
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T> T operator()(T x) { return metal::precise::rsqrt(x); };
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T> T operator()(T x) { return metal::precise::tan(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tan_a = metal::precise::tan(x.real);
|
||||
float tanh_b = metal::precise::tanh(x.imag);
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {
|
||||
(tan_a - tanh_b * t1) / denom,
|
||||
(tanh_b + tan_a * t1) / denom
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::tanh(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tanh_a = metal::precise::tanh(x.real);
|
||||
float tan_b = metal::precise::tan(x.imag);
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {
|
||||
(tanh_a + tan_b * t1) / denom,
|
||||
(tan_b - tanh_a * t1) / denom
|
||||
};
|
||||
};
|
||||
};
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
template <typename T, typename Op>
|
||||
[[kernel]] void unary_op_v(
|
||||
|
@ -12,10 +12,10 @@
|
||||
|
||||
template <typename U>
|
||||
struct Limits {
|
||||
static const constant U max;
|
||||
static const constant U min;
|
||||
static const constant U finite_max;
|
||||
static const constant U finite_min;
|
||||
static const constant U max = metal::numeric_limits<U>::max();
|
||||
static const constant U min = metal::numeric_limits<U>::min();
|
||||
static const constant U finite_max = metal::numeric_limits<U>::max();
|
||||
static const constant U finite_min = metal::numeric_limits<U>::min();
|
||||
};
|
||||
|
||||
#define instantiate_default_limit(type) \
|
||||
|
28
mlx/backend/metal/make_compiled_preamble.sh
Normal file
28
mlx/backend/metal/make_compiled_preamble.sh
Normal file
@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script generates a C++ function that provides the Metal unary and binary
|
||||
# ops at runtime for use with kernel generation.
|
||||
#
|
||||
# Copyright © 2023-24 Apple Inc.
|
||||
|
||||
|
||||
OUTPUT_FILE=$1
|
||||
CC=$2
|
||||
SRCDIR=$3
|
||||
|
||||
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
|
||||
|
||||
cat << EOF > $OUTPUT_FILE
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
const char* get_kernel_preamble() {
|
||||
return R"preamble(
|
||||
$CONTENT
|
||||
)preamble";
|
||||
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
EOF
|
@ -117,16 +117,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||
// When multiple arrays are passed they should all have the same shape. The
|
||||
// collapsed axes are also the same so one shape is returned.
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>> strides) {
|
||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||
// -1.
|
||||
std::vector<int> to_collapse;
|
||||
if (xs[0].ndim() > 0) {
|
||||
if (shape.size() > 0) {
|
||||
to_collapse.push_back(0);
|
||||
for (int i = 1; i < xs[0].ndim(); i++) {
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
for (auto& x : xs) {
|
||||
if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
|
||||
for (const std::vector<size_t>& st : strides) {
|
||||
if (st[i] * shape[i] != st[i - 1]) {
|
||||
contiguous = false;
|
||||
}
|
||||
if (!contiguous) {
|
||||
@ -142,21 +144,31 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
std::vector<std::vector<size_t>> out_strides(xs.size());
|
||||
std::vector<std::vector<size_t>> out_strides(strides.size());
|
||||
for (int i = 0; i < to_collapse.size(); i++) {
|
||||
int current_shape = xs[0].shape()[to_collapse[i]];
|
||||
int current_shape = shape[to_collapse[i]];
|
||||
while (to_collapse[++i] != -1) {
|
||||
current_shape *= xs[0].shape()[to_collapse[i]];
|
||||
current_shape *= shape[to_collapse[i]];
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < xs.size(); j++) {
|
||||
out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
const std::vector<size_t>& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[i - 1]]);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(out_shape, out_strides);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (auto& x : xs) {
|
||||
strides.emplace_back(x.strides());
|
||||
}
|
||||
return collapse_contiguous_dims(xs[0].shape(), strides);
|
||||
}
|
||||
|
||||
template <typename... Arrays>
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(Arrays... xs) {
|
||||
|
@ -13,7 +13,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int max_compile_depth = 6;
|
||||
constexpr int max_compile_depth = 10;
|
||||
|
||||
bool is_unary(const Primitive& p) {
|
||||
return (
|
||||
|
@ -12,10 +12,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
struct NodeNamer {
|
||||
std::unordered_map<std::uintptr_t, std::string> names;
|
||||
|
||||
std::string get_name(const array& x) {
|
||||
const std::string& NodeNamer::get_name(const array& x) {
|
||||
auto it = names.find(x.id());
|
||||
if (it == names.end()) {
|
||||
// Get the next name in the sequence
|
||||
@ -28,11 +25,11 @@ struct NodeNamer {
|
||||
}
|
||||
std::string name(letters.rbegin(), letters.rend());
|
||||
names.insert({x.id(), name});
|
||||
return name;
|
||||
|
||||
return get_name(x);
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
};
|
||||
|
||||
void depth_first_traversal(
|
||||
std::function<void(array)> callback,
|
||||
|
@ -6,6 +6,12 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
struct NodeNamer {
|
||||
std::unordered_map<std::uintptr_t, std::string> names;
|
||||
|
||||
const std::string& get_name(const array& x);
|
||||
};
|
||||
|
||||
void print_graph(std::ostream& os, const std::vector<array>& outputs);
|
||||
|
||||
template <typename... Arrays>
|
||||
|
@ -473,23 +473,30 @@ class Compiled : public Primitive {
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
|
||||
void print(std::ostream& os) override;
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
std::string metal_lib_name() const {
|
||||
return kernel_lib_;
|
||||
}
|
||||
std::string metal_lib_source() const {
|
||||
return kernel_source_;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<array> inputs_;
|
||||
const std::vector<array> outputs_;
|
||||
const std::vector<array> tape_;
|
||||
const std::unordered_set<uintptr_t> constant_ids_;
|
||||
|
||||
std::string kernel_lib_;
|
||||
std::string kernel_source_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& out);
|
||||
};
|
||||
|
||||
@ -709,9 +716,16 @@ class Equal : public UnaryPrimitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Equal)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
if (equal_nan_) {
|
||||
os << "NanEqual";
|
||||
} else {
|
||||
os << "Equal";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
bool equal_nan_;
|
||||
@ -945,9 +959,22 @@ class Log : public UnaryPrimitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Log)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
switch (base_) {
|
||||
case e:
|
||||
os << "Log";
|
||||
break;
|
||||
case two:
|
||||
os << "Log2";
|
||||
break;
|
||||
case ten:
|
||||
os << "Log10";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Base base_;
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
@ -1594,9 +1621,16 @@ class Sqrt : public UnaryPrimitive {
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Sqrt)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
if (recip_) {
|
||||
os << "Rsqrt";
|
||||
} else {
|
||||
os << "Sqrt";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
bool recip_;
|
||||
|
@ -623,3 +623,63 @@ TEST_CASE("test transform compiled function") {
|
||||
CHECK(!outs[0].inputs()[0].has_primitive());
|
||||
CHECK(!outs[0].inputs()[1].has_primitive());
|
||||
}
|
||||
|
||||
TEST_CASE("test metal fusion kernel reuse") {
|
||||
if (default_device() != Device::gpu) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto cfun = compile(gelu_1);
|
||||
auto x = array({2.0f, -2.0f});
|
||||
auto y = cfun({x})[0];
|
||||
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
||||
eval(y);
|
||||
|
||||
std::string lib_name = p->metal_lib_name();
|
||||
std::string lib_source = p->metal_lib_source();
|
||||
CHECK(!lib_name.empty());
|
||||
CHECK(!lib_source.empty());
|
||||
|
||||
x = astype(reshape(arange(10), {2, 5}), float32);
|
||||
auto z = cfun({x})[0];
|
||||
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
||||
eval(z);
|
||||
|
||||
std::string lib_name_z = pz->metal_lib_name();
|
||||
std::string lib_source_z = pz->metal_lib_source();
|
||||
CHECK(!lib_name_z.empty());
|
||||
CHECK(lib_source_z.empty());
|
||||
|
||||
CHECK_EQ(lib_name, lib_name_z);
|
||||
}
|
||||
|
||||
auto add3(const std::vector<array>& xs) {
|
||||
return std::vector<array>{xs[0] + xs[0] + xs[0]};
|
||||
}
|
||||
|
||||
TEST_CASE("test metal fusion types") {
|
||||
if (default_device() != Device::gpu) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto cfun = compile(add3);
|
||||
auto x = array({2.0f, -2.0f});
|
||||
auto y = cfun({x})[0];
|
||||
auto p = std::dynamic_pointer_cast<Compiled>(y.primitive_ptr());
|
||||
eval(y);
|
||||
|
||||
std::string lib_name = p->metal_lib_name();
|
||||
std::string lib_source = p->metal_lib_source();
|
||||
CHECK(!lib_name.empty());
|
||||
CHECK(!lib_source.empty());
|
||||
|
||||
x = array({2, -2}, int32);
|
||||
auto z = cfun({x})[0];
|
||||
auto pz = std::dynamic_pointer_cast<Compiled>(z.primitive_ptr());
|
||||
eval(z);
|
||||
|
||||
std::string lib_name_z = pz->metal_lib_name();
|
||||
std::string lib_source_z = pz->metal_lib_source();
|
||||
CHECK(!lib_name_z.empty());
|
||||
CHECK(!lib_source_z.empty());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user