Mirror of https://github.com/ml-explore/mlx.git, synced 2025-07-17 22:51:19 +08:00
Add Primitive::name and remove Primitive::print (#2365)
This commit is contained in:
parent 5201df5030
commit d34f887abc
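In short: the virtual `Primitive::print(std::ostream&)` hook is replaced by `Primitive::name()`, which returns a C string. A minimal before/after sketch of the API, distilled from the hunks below (class body abbreviated):

    // Before: primitives wrote their name into a stream.
    //   void print(std::ostream& os) override { os << "Axpby"; }
    // After: primitives return their name directly.
    class Axpby : public mx::Primitive {
     public:
      const char* name() const override {
        return "Axpby";
      }
      // ...
    };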
@@ -138,13 +138,13 @@ more concrete:
    * representing the vectorized computation and the axis which
    * corresponds to the output vectorized dimension.
    */
-  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
+  std::pair<std::vector<array>, std::vector<int>> vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) override;
 
-  /** Print the primitive. */
-  void print(std::ostream& os) override {
-    os << "Axpby";
+  /** The name of the primitive. */
+  const char* name() const override {
+    return "Axpby";
   }
 
   /** Equivalence check **/
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
       const std::vector<mx::array>& inputs,
       const std::vector<int>& axes) override;
 
-  /** Print the primitive. */
-  void print(std::ostream& os) override {
-    os << "Axpby";
+  /** The name of the primitive. */
+  const char* name() const override {
+    return "Axpby";
   }
 
   /** Equivalence check **/
@@ -3,16 +3,9 @@
 #include <dlfcn.h>
 
 #include "mlx/backend/common/utils.h"
-#include "mlx/primitives.h"
 
 namespace mlx::core {
 
-std::string get_primitive_string(Primitive* primitive) {
-  std::ostringstream op_t;
-  primitive->print(op_t);
-  return op_t.str();
-}
-
 std::filesystem::path current_binary_dir() {
   static std::filesystem::path binary_dir = []() {
     Dl_info info;
@@ -10,8 +10,6 @@
 
 namespace mlx::core {
 
-std::string get_primitive_string(Primitive* primitive);
-
 // Return the directory that contains current shared library.
 std::filesystem::path current_binary_dir();
 
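With `name()` available, the `get_primitive_string` helper (an ostringstream round-trip through `print`) is deleted and call sites take the name directly. A hypothetical call-site migration, mirroring the deleted helper's body:

    // Before: materialize the name via the stream-based print().
    std::ostringstream op_t;
    primitive->print(op_t);
    std::string op = op_t.str();

    // After: one call, no stream round-trip.
    std::string op = primitive->name();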
@@ -231,7 +231,7 @@ inline void build_kernel(
       os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
          << namer.get_name(x.inputs()[0]) << ");" << std::endl;
     } else {
-      x.primitive().print(os);
+      os << x.primitive().name();
       os << "()(";
       for (int i = 0; i < x.inputs().size() - 1; i++) {
         os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
@@ -177,7 +177,7 @@ template <typename Op>
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    std::string_view op,
+    const char* op,
     const Stream& s) {
   assert(inputs.size() > 1);
   const auto& a = inputs[0];
@@ -291,7 +291,7 @@ template <typename Op>
 void binary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    std::string_view op,
+    const char* op,
     const Stream& s) {
   auto& a = inputs[0];
   auto& b = inputs[1];
@@ -300,11 +300,11 @@ void binary_op_gpu(
   binary_op_gpu_inplace<Op>(inputs, out, op, s);
 }
 
-#define BINARY_GPU(func) \
-  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
-    nvtx3::scoped_range r(#func "::eval_gpu"); \
-    auto& s = out.primitive().stream(); \
-    binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
+#define BINARY_GPU(func) \
+  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
+    nvtx3::scoped_range r(#func "::eval_gpu"); \
+    auto& s = out.primitive().stream(); \
+    binary_op_gpu<cu::func>(inputs, out, name(), s); \
   }
 
 BINARY_GPU(Add)
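For reference, hand-expanding the updated macro for one instantiation shows the generated method; this expansion is illustrative and not itself part of the diff:

    // BINARY_GPU(Add) expands to approximately:
    void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
      nvtx3::scoped_range r("Add::eval_gpu"); // from #func "::eval_gpu"
      auto& s = out.primitive().stream();
      binary_op_gpu<cu::Add>(inputs, out, name(), s); // name() yields "Add"
    }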
@@ -328,33 +328,31 @@ BINARY_GPU(Subtract)
 void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Equal::eval_gpu");
   auto& s = out.primitive().stream();
-  auto op = get_primitive_string(this);
   if (equal_nan_) {
-    binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
+    binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
   } else {
-    binary_op_gpu<cu::Equal>(inputs, out, op, s);
+    binary_op_gpu<cu::Equal>(inputs, out, name(), s);
   }
 }
 
 void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
   auto& s = out.primitive().stream();
-  auto op = get_primitive_string(this);
   switch (op_) {
     case BitwiseBinary::And:
-      binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
+      binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
       break;
     case BitwiseBinary::Or:
-      binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
+      binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
       break;
     case BitwiseBinary::Xor:
-      binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
+      binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
       break;
     case BitwiseBinary::LeftShift:
-      binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
+      binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
       break;
     case BitwiseBinary::RightShift:
-      binary_op_gpu<cu::RightShift>(inputs, out, op, s);
+      binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
       break;
   }
 }
@@ -184,7 +184,7 @@ template <typename Op>
 void binary_two_op_gpu_inplace(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    std::string_view op,
+    const char* op,
     const Stream& s) {
   assert(inputs.size() > 1);
   const auto& a = inputs[0];
@@ -314,7 +314,7 @@ template <typename Op>
 void binary_two_op_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    std::string_view op,
+    const char* op,
     const Stream& s) {
   auto& a = inputs[0];
   auto& b = inputs[1];
@@ -329,7 +329,7 @@ void DivMod::eval_gpu(
     std::vector<array>& outputs) {
   nvtx3::scoped_range r("DivMod::eval_gpu");
   auto& s = outputs[0].primitive().stream();
-  binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
+  binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
 }
 
 } // namespace mlx::core
@@ -106,9 +106,7 @@ struct FusedKernelBuilder {
       value = fmt::format(
           "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
     } else {
-      std::ostringstream ss;
-      x.primitive().print(ss);
-      value = ss.str();
+      value = x.primitive().name();
       value += "{}(";
       for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
         value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
@@ -102,7 +102,7 @@ template <typename Op>
 void unary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string& op,
+    const char* op,
     const Stream& s) {
   auto& in = inputs[0];
   if (in.size() == 0) {
@@ -178,17 +178,17 @@ template <typename Op>
 void unary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string& op,
+    const char* op,
     const Stream& s) {
   set_unary_output_data(inputs[0], out);
   unary_op_gpu_inplace<Op>(inputs, out, op, s);
 }
 
-#define UNARY_GPU(func) \
-  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
-    nvtx3::scoped_range r(#func "::eval_gpu"); \
-    auto& s = out.primitive().stream(); \
-    unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
+#define UNARY_GPU(func) \
+  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
+    nvtx3::scoped_range r(#func "::eval_gpu"); \
+    auto& s = out.primitive().stream(); \
+    unary_op_gpu<cu::func>(inputs, out, name(), s); \
   }
 
 UNARY_GPU(Abs)
@@ -224,16 +224,15 @@ UNARY_GPU(Tanh)
 void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Log::eval_gpu");
   auto& s = out.primitive().stream();
-  auto op = get_primitive_string(this);
   switch (base_) {
     case Base::e:
-      unary_op_gpu<cu::Log>(inputs, out, op, s);
+      unary_op_gpu<cu::Log>(inputs, out, name(), s);
       break;
     case Base::two:
-      unary_op_gpu<cu::Log2>(inputs, out, op, s);
+      unary_op_gpu<cu::Log2>(inputs, out, name(), s);
       break;
    case Base::ten:
-      unary_op_gpu<cu::Log10>(inputs, out, op, s);
+      unary_op_gpu<cu::Log10>(inputs, out, name(), s);
       break;
   }
 }
@@ -244,7 +243,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
   const auto& in = inputs[0];
   auto& s = out.primitive().stream();
   if (issubdtype(in.dtype(), inexact)) {
-    unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
+    unary_op_gpu<cu::Round>(inputs, out, name(), s);
   } else {
     // No-op integer types
     out.copy_shared_buffer(in);
@@ -7,20 +7,20 @@
 
 #define BINARY_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
-    binary_op_gpu(inputs, out, get_primitive_string(this)); \
+    binary_op_gpu(inputs, out, name()); \
   }
 
 #define BINARY_GPU_MULTI(func) \
   void func::eval_gpu( \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
-    binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
+    binary_op_gpu(inputs, outputs, name()); \
   }
 
 namespace mlx::core {
 
 std::string get_kernel_name(
     BinaryOpType bopt,
-    const std::string& op,
+    const char* op,
     const array& a,
     bool large,
     int ndim,
@@ -65,7 +65,7 @@ std::string get_kernel_name(
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string& op,
+    const char* op,
     const Stream& s) {
   auto& a = inputs[0];
   auto& b = inputs[1];
@@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
 void binary_op_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string& op,
+    const char* op,
     const Stream& s) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
@@ -179,7 +179,7 @@ void binary_op_gpu(
 void binary_op_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string& op) {
+    const char* op) {
   auto& s = outputs[0].primitive().stream();
   binary_op_gpu(inputs, outputs, op, s);
 }
@@ -187,7 +187,7 @@ void binary_op_gpu(
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string& op,
+    const char* op,
     const Stream& s) {
   std::vector<array> outputs = {out};
   binary_op_gpu_inplace(inputs, outputs, op, s);
@@ -196,7 +196,7 @@ void binary_op_gpu_inplace(
 void binary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string& op,
+    const char* op,
     const Stream& s) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
@@ -209,7 +209,7 @@ void binary_op_gpu(
 void binary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string& op) {
+    const char* op) {
   auto& s = out.primitive().stream();
   binary_op_gpu(inputs, out, op, s);
 }
@@ -237,19 +237,19 @@ BINARY_GPU(Subtract)
 void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
   switch (op_) {
     case BitwiseBinary::And:
-      binary_op_gpu(inputs, out, get_primitive_string(this));
+      binary_op_gpu(inputs, out, name());
       break;
     case BitwiseBinary::Or:
-      binary_op_gpu(inputs, out, get_primitive_string(this));
+      binary_op_gpu(inputs, out, name());
       break;
     case BitwiseBinary::Xor:
-      binary_op_gpu(inputs, out, get_primitive_string(this));
+      binary_op_gpu(inputs, out, name());
       break;
     case BitwiseBinary::LeftShift:
-      binary_op_gpu(inputs, out, get_primitive_string(this));
+      binary_op_gpu(inputs, out, name());
       break;
    case BitwiseBinary::RightShift:
-      binary_op_gpu(inputs, out, get_primitive_string(this));
+      binary_op_gpu(inputs, out, name());
       break;
   }
 }
@@ -9,25 +9,25 @@ namespace mlx::core {
 void binary_op_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string& op,
+    const char* op,
     const Stream& s);
 
 void binary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string& op,
+    const char* op,
     const Stream& s);
 
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::string& op,
+    const char* op,
     const Stream& s);
 
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string& op,
+    const char* op,
     const Stream& s);
 
 } // namespace mlx::core
@@ -212,9 +212,7 @@ inline void build_kernel(
           get_type_string(x.dtype()),
           namer.get_name(x.inputs()[0]));
     } else {
-      std::ostringstream ss;
-      x.primitive().print(ss);
-      os += ss.str();
+      os += x.primitive().name();
       os += "()(";
       for (int i = 0; i < x.inputs().size() - 1; i++) {
         os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
@@ -8,12 +8,6 @@ using namespace fmt::literals;
 
 namespace mlx::core {
 
-std::string op_name(const array& arr) {
-  std::ostringstream op_t;
-  arr.primitive().print(op_t);
-  return op_t.str();
-}
-
 MTL::ComputePipelineState* get_arange_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel(
     const std::string& kernel_name,
     Dtype in_type,
     Dtype out_type,
-    const std::string op) {
+    const char* op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
     auto in_t = get_type_string(in_type);
@@ -58,10 +52,10 @@ MTL::ComputePipelineState* get_unary_kernel(
 }
 
 void append_binary_kernels(
-    const std::string lib_name,
+    const std::string& lib_name,
     Dtype in_type,
     Dtype out_type,
-    const std::string op,
+    const char* op,
     std::string& kernel_source) {
   const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
       {"ss", "binary_ss"},
@@ -112,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel(
     const std::string& kernel_name,
     Dtype in_type,
     Dtype out_type,
-    const std::string op) {
+    const char* op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
     std::string kernel_source;
@@ -129,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
     const std::string& kernel_name,
     Dtype in_type,
     Dtype out_type,
-    const std::string op) {
+    const char* op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
     std::string kernel_source = metal::utils();
@@ -144,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     Dtype type,
-    const std::string op) {
+    const char* op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
     auto t_str = get_type_string(type);
@@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel(
     const std::string& kernel_name,
     Dtype in_type,
     Dtype out_type,
-    const std::string op);
+    const char* op);
 
 MTL::ComputePipelineState* get_binary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     Dtype in_type,
     Dtype out_type,
-    const std::string op);
+    const char* op);
 
 MTL::ComputePipelineState* get_binary_two_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     Dtype in_type,
     Dtype out_type,
-    const std::string op);
+    const char* op);
 
 MTL::ComputePipelineState* get_ternary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     Dtype type,
-    const std::string op);
+    const char* op);
 
 MTL::ComputePipelineState* get_copy_kernel(
     metal::Device& d,
@@ -257,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
 
 // Create a GPU kernel template definition for JIT compilation
 template <typename... Args>
-std::string
-get_template_definition(std::string name, std::string func, Args... args) {
+std::string get_template_definition(
+    std::string_view name,
+    std::string_view func,
+    Args... args) {
   std::ostringstream s;
   s << func << "<";
   bool first = true;
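Switching the first two parameters from `std::string` to `std::string_view` keeps call sites unchanged (string literals and `std::string` both convert) while avoiding a string copy per JIT lookup. A hypothetical call, with illustrative argument values that are not taken from the diff:

    // Builds the "func<args...>" template-instantiation line for a
    // JIT-compiled kernel, registered under the given kernel name.
    std::string def =
        get_template_definition("vabs_float32", "unary_v", "float", "float");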
@@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel(
     const std::string& kernel_name,
     Dtype,
     Dtype,
-    const std::string) {
+    const char*) {
   return d.get_kernel(kernel_name);
 }
 
@@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel(
     const std::string& kernel_name,
     Dtype,
     Dtype,
-    const std::string) {
+    const char*) {
   return d.get_kernel(kernel_name);
 }
 
@@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
     const std::string& kernel_name,
     Dtype,
     Dtype,
-    const std::string) {
+    const char*) {
   return d.get_kernel(kernel_name);
 }
 
@@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     Dtype,
-    const std::string) {
+    const char*) {
   return d.get_kernel(kernel_name);
 }
 
@@ -11,7 +11,7 @@ namespace mlx::core {
 void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const char* op,
     const Stream& s) {
   assert(inputs.size() == 3);
   auto& a = inputs[0];
@@ -128,7 +128,7 @@ void ternary_op_gpu_inplace(
 void ternary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const char* op,
     const Stream& s) {
   auto& a = inputs[0];
   auto& b = inputs[1];
@@ -141,13 +141,13 @@ void ternary_op_gpu(
 void ternary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op) {
+    const char* op) {
   auto& s = out.primitive().stream();
   ternary_op_gpu(inputs, out, op, s);
 }
 
 void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
-  ternary_op_gpu(inputs, out, get_primitive_string(this));
+  ternary_op_gpu(inputs, out, name());
 }
 
 } // namespace mlx::core
@@ -9,13 +9,13 @@ namespace mlx::core {
 void ternary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const char* op,
     const Stream& s);
 
 void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const char* op,
     const Stream& s);
 
 } // namespace mlx::core
@@ -8,7 +8,7 @@
 
 #define UNARY_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
-    unary_op_gpu(inputs, out, get_primitive_string(this)); \
+    unary_op_gpu(inputs, out, name()); \
   }
 
 namespace mlx::core {
@@ -16,7 +16,7 @@ namespace mlx::core {
 void unary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const char* op,
     const Stream& s) {
   auto& in = inputs[0];
   bool contig = in.flags().contiguous;
@@ -98,7 +98,7 @@ void unary_op_gpu_inplace(
 void unary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const char* op,
     const Stream& s) {
   set_unary_output_data(inputs[0], out);
   unary_op_gpu_inplace(inputs, out, op, s);
@@ -107,7 +107,7 @@ void unary_op_gpu(
 void unary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op) {
+    const char* op) {
   auto& s = out.primitive().stream();
   unary_op_gpu(inputs, out, op, s);
 }
@@ -146,13 +146,13 @@ UNARY_GPU(Tanh)
 void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
   switch (base_) {
     case Base::e:
-      unary_op_gpu(inputs, out, get_primitive_string(this));
+      unary_op_gpu(inputs, out, name());
       break;
     case Base::two:
-      unary_op_gpu(inputs, out, get_primitive_string(this));
+      unary_op_gpu(inputs, out, name());
       break;
    case Base::ten:
-      unary_op_gpu(inputs, out, get_primitive_string(this));
+      unary_op_gpu(inputs, out, name());
       break;
   }
 }
@@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (issubdtype(in.dtype(), inexact)) {
-    unary_op_gpu(inputs, out, get_primitive_string(this));
+    unary_op_gpu(inputs, out, name());
   } else {
     // No-op integer types
     out.copy_shared_buffer(in);
@@ -9,13 +9,13 @@ namespace mlx::core {
 void unary_op_gpu(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const char* op,
     const Stream& s);
 
 void unary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op,
+    const char* op,
     const Stream& s);
 
 } // namespace mlx::core
@@ -40,7 +40,7 @@ inline void debug_set_primitive_buffer_label(
   if (auto cbuf_label = command_buffer->label(); cbuf_label) {
     label << cbuf_label->utf8String();
   }
-  primitive.print(label);
+  label << primitive.name();
   command_buffer->setLabel(make_string(label));
 #endif
 }
@@ -107,7 +107,7 @@ Compiled::Compiled(
   // name and type of output
   os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
   // computation performed
-  a.primitive().print(os);
+  os << a.primitive().name();
   // name of inputs to the function
   for (auto& inp : a.inputs()) {
     os << namer.get_name(inp);
@@ -170,11 +170,16 @@ bool Compiled::is_equivalent(const Primitive& other) const {
       });
 }
 
-void Compiled::print(std::ostream& os) {
-  os << "Compiled";
-  for (auto& a : tape_) {
-    a.primitive().print(os);
+const char* Compiled::name() const {
+  if (name_.empty()) {
+    std::ostringstream os;
+    os << "Compiled";
+    for (auto& a : tape_) {
+      os << a.primitive().name();
    }
+    name_ = os.str();
+  }
+  return name_.c_str();
 }
 
 std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
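Since `name()` is a `const` member that builds its string lazily on first call, the cached `name_` field presumably has to be `mutable` (an assumption here; the member declaration itself lives in the suppressed `mlx/primitives.h` diff below). Returning `name_.c_str()` is safe because the cached string lives as long as the primitive. The memoization pattern, sketched:

    class Compiled : public Primitive {
      // ...
      mutable std::string name_; // assumed declaration; empty until name() runs
    };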
@@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
-  void print(std::ostream& os) override {
+  const char* name() const override {
     switch (reduce_type_) {
       case And:
-        os << "And";
+        return "And AllReduce";
       case Or:
-        os << "And";
-        break;
+        return "Or AllReduce";
       case Sum:
-        os << "Sum";
-        break;
+        return "Sum AllReduce";
      case Prod:
-        os << "Prod";
-        break;
+        return "Prod AllReduce";
      case Min:
-        os << "Min";
-        break;
+        return "Min AllReduce";
      case Max:
-        os << "Max";
-        break;
+        return "Max AllReduce";
     }
-    os << " AllReduce";
+    return "<unknown AllReduce>";
   }
 
  private:
@@ -94,7 +89,7 @@ class AllGather : public DistPrimitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
-  DEFINE_PRINT(AllGather);
+  DEFINE_NAME(AllGather);
 };
 
 class Send : public DistPrimitive {
@@ -110,7 +105,7 @@ class Send : public DistPrimitive {
       const std::vector<array>& inputs,
       const std::vector<int>& axes) override;
 
-  DEFINE_PRINT(Send);
+  DEFINE_NAME(Send);
 
  private:
   int dst_;
@@ -126,7 +121,7 @@ class Recv : public DistPrimitive {
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
 
-  DEFINE_PRINT(Recv);
+  DEFINE_NAME(Recv);
 
  private:
   int src_;
@@ -354,9 +354,7 @@ struct PrimitiveFactory {
 
   void save(Writer& os, const std::shared_ptr<Primitive>& p) {
     serialize(os, p->stream());
-    std::ostringstream pout;
-    p->print(pout);
-    auto name = pout.str();
+    std::string name = p->name();
     name = name.substr(0, name.find(' '));
     if (auto it = name_remap.find(name); it != name_remap.end()) {
       name = it->second;
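The serializer keeps only the first whitespace-delimited token of the name, so parameterized names still collapse to a single serialized type. For example, using names produced elsewhere in this diff:

    std::string name = "Sum AllReduce";
    name = name.substr(0, name.find(' ')); // "Sum"
    // "View float32" -> "View"; names without a space (e.g. "Add") pass
    // through unchanged: find(' ') returns npos, and substr(0, npos)
    // yields the whole string.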
@@ -58,7 +58,7 @@ class RMSNorm : public Custom {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
-  DEFINE_PRINT(RMSNorm)
+  DEFINE_NAME(RMSNorm)
   bool is_equivalent(const Primitive& other) const override;
   DEFINE_INPUT_OUTPUT_SHAPE()
 
@@ -85,7 +85,7 @@ class RMSNormVJP : public Custom {
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
 
-  DEFINE_PRINT(RMSNormVJP)
+  DEFINE_NAME(RMSNormVJP)
   bool is_equivalent(const Primitive& other) const override;
   auto state() const {
     return std::make_pair(nullptr, eps_);
@@ -118,7 +118,7 @@ class LayerNorm : public Custom {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
-  DEFINE_PRINT(LayerNorm)
+  DEFINE_NAME(LayerNorm)
   bool is_equivalent(const Primitive& other) const override;
   DEFINE_INPUT_OUTPUT_SHAPE()
   auto state() const {
@@ -144,7 +144,7 @@ class LayerNormVJP : public Custom {
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
 
-  DEFINE_PRINT(LayerNormVJP)
+  DEFINE_NAME(LayerNormVJP)
   bool is_equivalent(const Primitive& other) const override;
   auto state() const {
     return std::make_pair(nullptr, eps_);
@@ -186,7 +186,7 @@ class RoPE : public Custom {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
-  DEFINE_PRINT(RoPE)
+  DEFINE_NAME(RoPE)
   bool is_equivalent(const Primitive& other) const override;
   DEFINE_INPUT_OUTPUT_SHAPE()
   auto state() const {
@@ -233,7 +233,7 @@ class ScaledDotProductAttention : public Custom {
   void eval_gpu(const std::vector<array>& inputs, array& out);
   bool is_equivalent(const Primitive& other) const override;
 
-  DEFINE_PRINT(ScaledDotProductAttention);
+  DEFINE_NAME(ScaledDotProductAttention);
   DEFINE_INPUT_OUTPUT_SHAPE()
   auto state() const {
     return std::make_tuple(nullptr, scale_, do_causal_);
@@ -263,7 +263,7 @@ class AffineQuantize : public Custom {
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
 
-  DEFINE_PRINT(AffineQuantize);
+  DEFINE_NAME(AffineQuantize);
 
   bool is_equivalent(const Primitive& other) const override;
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
@@ -311,7 +311,7 @@ class CustomKernel : public Primitive {
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
 
-  DEFINE_PRINT(CustomKernel);
+  DEFINE_NAME(CustomKernel);
 
  private:
   std::string source_;
@@ -93,7 +93,7 @@ void print_graph(
   os << "\n";
 
   for (auto& arr : tape) {
-    arr.primitive().print(os);
+    os << arr.primitive().name();
     os << " ";
     print_arrs(arr.inputs());
     os << " -> ";
@@ -143,7 +143,7 @@ void export_to_dot(
     os << "{ ";
     os << x.primitive_id();
     os << " [label =\"";
-    x.primitive().print(os);
+    os << x.primitive().name();
     os << "\", shape=rectangle]";
     os << "; }" << std::endl;
     // Arrows to primitive's inputs
@@ -500,7 +500,7 @@ array cross(
 void validate_eig(
     const array& a,
     const StreamOrDevice& stream,
-    const std::string fname) {
+    const std::string& fname) {
   check_cpu_stream(stream, fname);
   check_float_or_complex(a.dtype(), fname);
 
@@ -181,7 +181,7 @@ std::vector<array> Primitive::jvp(
     const std::vector<int>&) {
   std::ostringstream msg;
   msg << "[Primitive::jvp] Not implemented for ";
-  print(msg);
+  msg << name();
   msg << ".";
   throw std::invalid_argument(msg.str());
 }
@@ -193,7 +193,7 @@ std::vector<array> Primitive::vjp(
     const std::vector<array>&) {
   std::ostringstream msg;
   msg << "[Primitive::vjp] Not implemented for ";
-  print(msg);
+  msg << name();
   msg << ".";
   throw std::invalid_argument(msg.str());
 }
@@ -203,7 +203,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
     const std::vector<int>&) {
   std::ostringstream msg;
   msg << "[Primitive::vmap] Not implemented for ";
-  print(msg);
+  msg << name();
   msg << ".";
   throw std::invalid_argument(msg.str());
 }
@@ -211,7 +211,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
 std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {
   std::ostringstream msg;
   msg << "[Primitive::output_shapes] ";
-  this->print(msg);
+  msg << name();
   msg << " cannot infer output shapes.";
   throw std::invalid_argument(msg.str());
 }
@@ -743,26 +743,6 @@ bool BitwiseBinary::is_equivalent(const Primitive& other) const {
   return op_ == a_other.op_;
 }
 
-void BitwiseBinary::print(std::ostream& os) {
-  switch (op_) {
-    case BitwiseBinary::And:
-      os << "BitwiseAnd";
-      break;
-    case BitwiseBinary::Or:
-      os << "BitwiseOr";
-      break;
-    case BitwiseBinary::Xor:
-      os << "BitwiseXor";
-      break;
-    case BitwiseBinary::LeftShift:
-      os << "LeftShift";
-      break;
-    case BitwiseBinary::RightShift:
-      os << "RightShift";
-      break;
-  }
-}
-
 std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -5375,8 +5355,13 @@ std::pair<std::vector<array>, std::vector<int>> View::vmap(
   return {{view(inputs[0], dtype_, stream())}, axes};
 }
 
-void View::print(std::ostream& os) {
-  os << "View " << dtype_;
+const char* View::name() const {
+  if (name_.empty()) {
+    std::ostringstream os;
+    os << "View " << dtype_;
+    name_ = os.str();
+  }
+  return name_.c_str();
 }
 
 bool View::is_equivalent(const Primitive& other) const {
mlx/primitives.h (332 additions, 332 deletions): file diff suppressed because it is too large.
@@ -33,7 +33,7 @@ class Synchronizer : public Primitive {
   void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}
   void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}
 
-  DEFINE_PRINT(Synchronize);
+  DEFINE_NAME(Synchronize);
 };
 
 // Initialize the static tracing members from transforms_impl.h
@@ -514,7 +514,7 @@ void init_linalg(nb::module_& parent_module) {
       )pbdoc");
   m.def(
       "eigh",
-      [](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
+      [](const mx::array& a, const std::string& UPLO, mx::StreamOrDevice s) {
        auto result = mx::linalg::eigh(a, UPLO, s);
        return nb::make_tuple(result.first, result.second);
      },
@@ -14,7 +14,7 @@ namespace mx = mlx::core;
 namespace nb = nanobind;
 using namespace nb::literals;
 
-bool DEPRECATE(const std::string& old_fn, const std::string new_fn) {
+bool DEPRECATE(const char* old_fn, const char* new_fn) {
   std::cerr << old_fn << " is deprecated and will be removed in a future "
             << "version. Use " << new_fn << " instead." << std::endl;
   return true;
@@ -3076,7 +3076,7 @@ void init_ops(nb::module_& m) {
           std::tuple<int>,
          std::pair<int, int>,
          std::vector<std::pair<int, int>>>& pad_width,
-      const std::string mode,
+      const std::string& mode,
       const ScalarOrArray& constant_value,
       mx::StreamOrDevice s) {
     if (auto pv = std::get_if<int>(&pad_width); pv) {