mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 22:51:19 +08:00
Add Primitive::name and remove Primitive::print (#2365)
This commit is contained in:
parent
5201df5030
commit
d34f887abc
@ -138,13 +138,13 @@ more concrete:
|
|||||||
* representing the vectorized computation and the axis which
|
* representing the vectorized computation and the axis which
|
||||||
* corresponds to the output vectorized dimension.
|
* corresponds to the output vectorized dimension.
|
||||||
*/
|
*/
|
||||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
/** Print the primitive. */
|
/** The name of primitive. */
|
||||||
void print(std::ostream& os) override {
|
const char* name() const override {
|
||||||
os << "Axpby";
|
return "Axpby";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Equivalence check **/
|
/** Equivalence check **/
|
||||||
|
@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
|
|||||||
const std::vector<mx::array>& inputs,
|
const std::vector<mx::array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
/** Print the primitive. */
|
/** The name of primitive. */
|
||||||
void print(std::ostream& os) override {
|
const char* name() const override {
|
||||||
os << "Axpby";
|
return "Axpby";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Equivalence check **/
|
/** Equivalence check **/
|
||||||
|
@ -3,16 +3,9 @@
|
|||||||
#include <dlfcn.h>
|
#include <dlfcn.h>
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive) {
|
|
||||||
std::ostringstream op_t;
|
|
||||||
primitive->print(op_t);
|
|
||||||
return op_t.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::filesystem::path current_binary_dir() {
|
std::filesystem::path current_binary_dir() {
|
||||||
static std::filesystem::path binary_dir = []() {
|
static std::filesystem::path binary_dir = []() {
|
||||||
Dl_info info;
|
Dl_info info;
|
||||||
|
@ -10,8 +10,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive);
|
|
||||||
|
|
||||||
// Return the directory that contains current shared library.
|
// Return the directory that contains current shared library.
|
||||||
std::filesystem::path current_binary_dir();
|
std::filesystem::path current_binary_dir();
|
||||||
|
|
||||||
|
@ -231,7 +231,7 @@ inline void build_kernel(
|
|||||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||||
} else {
|
} else {
|
||||||
x.primitive().print(os);
|
os << x.primitive().name();
|
||||||
os << "()(";
|
os << "()(";
|
||||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||||
|
@ -177,7 +177,7 @@ template <typename Op>
|
|||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
std::string_view op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() > 1);
|
assert(inputs.size() > 1);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
@ -291,7 +291,7 @@ template <typename Op>
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
std::string_view op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
@ -300,11 +300,11 @@ void binary_op_gpu(
|
|||||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BINARY_GPU(func) \
|
#define BINARY_GPU(func) \
|
||||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||||
auto& s = out.primitive().stream(); \
|
auto& s = out.primitive().stream(); \
|
||||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
binary_op_gpu<cu::func>(inputs, out, name(), s); \
|
||||||
}
|
}
|
||||||
|
|
||||||
BINARY_GPU(Add)
|
BINARY_GPU(Add)
|
||||||
@ -328,33 +328,31 @@ BINARY_GPU(Subtract)
|
|||||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
auto op = get_primitive_string(this);
|
|
||||||
if (equal_nan_) {
|
if (equal_nan_) {
|
||||||
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
|
||||||
} else {
|
} else {
|
||||||
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
auto op = get_primitive_string(this);
|
|
||||||
switch (op_) {
|
switch (op_) {
|
||||||
case BitwiseBinary::And:
|
case BitwiseBinary::And:
|
||||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
|
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Or:
|
case BitwiseBinary::Or:
|
||||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
|
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Xor:
|
case BitwiseBinary::Xor:
|
||||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
|
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::LeftShift:
|
case BitwiseBinary::LeftShift:
|
||||||
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
|
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::RightShift:
|
case BitwiseBinary::RightShift:
|
||||||
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
|
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -184,7 +184,7 @@ template <typename Op>
|
|||||||
void binary_two_op_gpu_inplace(
|
void binary_two_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
std::string_view op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() > 1);
|
assert(inputs.size() > 1);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
@ -314,7 +314,7 @@ template <typename Op>
|
|||||||
void binary_two_op_gpu(
|
void binary_two_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
std::string_view op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
@ -329,7 +329,7 @@ void DivMod::eval_gpu(
|
|||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
nvtx3::scoped_range r("DivMod::eval_gpu");
|
nvtx3::scoped_range r("DivMod::eval_gpu");
|
||||||
auto& s = outputs[0].primitive().stream();
|
auto& s = outputs[0].primitive().stream();
|
||||||
binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
|
binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -106,9 +106,7 @@ struct FusedKernelBuilder {
|
|||||||
value = fmt::format(
|
value = fmt::format(
|
||||||
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
|
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream ss;
|
value = x.primitive().name();
|
||||||
x.primitive().print(ss);
|
|
||||||
value = ss.str();
|
|
||||||
value += "{}(";
|
value += "{}(";
|
||||||
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
|
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
|
||||||
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
|
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
|
||||||
|
@ -102,7 +102,7 @@ template <typename Op>
|
|||||||
void unary_op_gpu_inplace(
|
void unary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (in.size() == 0) {
|
if (in.size() == 0) {
|
||||||
@ -178,17 +178,17 @@ template <typename Op>
|
|||||||
void unary_op_gpu(
|
void unary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
set_unary_output_data(inputs[0], out);
|
set_unary_output_data(inputs[0], out);
|
||||||
unary_op_gpu_inplace<Op>(inputs, out, op, s);
|
unary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define UNARY_GPU(func) \
|
#define UNARY_GPU(func) \
|
||||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||||
auto& s = out.primitive().stream(); \
|
auto& s = out.primitive().stream(); \
|
||||||
unary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
unary_op_gpu<cu::func>(inputs, out, name(), s); \
|
||||||
}
|
}
|
||||||
|
|
||||||
UNARY_GPU(Abs)
|
UNARY_GPU(Abs)
|
||||||
@ -224,16 +224,15 @@ UNARY_GPU(Tanh)
|
|||||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("Log::eval_gpu");
|
nvtx3::scoped_range r("Log::eval_gpu");
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
auto op = get_primitive_string(this);
|
|
||||||
switch (base_) {
|
switch (base_) {
|
||||||
case Base::e:
|
case Base::e:
|
||||||
unary_op_gpu<cu::Log>(inputs, out, op, s);
|
unary_op_gpu<cu::Log>(inputs, out, name(), s);
|
||||||
break;
|
break;
|
||||||
case Base::two:
|
case Base::two:
|
||||||
unary_op_gpu<cu::Log2>(inputs, out, op, s);
|
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
|
||||||
break;
|
break;
|
||||||
case Base::ten:
|
case Base::ten:
|
||||||
unary_op_gpu<cu::Log10>(inputs, out, op, s);
|
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -244,7 +243,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
if (issubdtype(in.dtype(), inexact)) {
|
if (issubdtype(in.dtype(), inexact)) {
|
||||||
unary_op_gpu<cu::Round>(inputs, out, get_primitive_string(this), s);
|
unary_op_gpu<cu::Round>(inputs, out, name(), s);
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
|
@ -7,20 +7,20 @@
|
|||||||
|
|
||||||
#define BINARY_GPU(func) \
|
#define BINARY_GPU(func) \
|
||||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||||
binary_op_gpu(inputs, out, get_primitive_string(this)); \
|
binary_op_gpu(inputs, out, name()); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BINARY_GPU_MULTI(func) \
|
#define BINARY_GPU_MULTI(func) \
|
||||||
void func::eval_gpu( \
|
void func::eval_gpu( \
|
||||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||||
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
|
binary_op_gpu(inputs, outputs, name()); \
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string get_kernel_name(
|
std::string get_kernel_name(
|
||||||
BinaryOpType bopt,
|
BinaryOpType bopt,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const array& a,
|
const array& a,
|
||||||
bool large,
|
bool large,
|
||||||
int ndim,
|
int ndim,
|
||||||
@ -65,7 +65,7 @@ std::string get_kernel_name(
|
|||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
@ -179,7 +179,7 @@ void binary_op_gpu(
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string& op) {
|
const char* op) {
|
||||||
auto& s = outputs[0].primitive().stream();
|
auto& s = outputs[0].primitive().stream();
|
||||||
binary_op_gpu(inputs, outputs, op, s);
|
binary_op_gpu(inputs, outputs, op, s);
|
||||||
}
|
}
|
||||||
@ -187,7 +187,7 @@ void binary_op_gpu(
|
|||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
std::vector<array> outputs = {out};
|
std::vector<array> outputs = {out};
|
||||||
binary_op_gpu_inplace(inputs, outputs, op, s);
|
binary_op_gpu_inplace(inputs, outputs, op, s);
|
||||||
@ -196,7 +196,7 @@ void binary_op_gpu_inplace(
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
@ -209,7 +209,7 @@ void binary_op_gpu(
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string& op) {
|
const char* op) {
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
binary_op_gpu(inputs, out, op, s);
|
binary_op_gpu(inputs, out, op, s);
|
||||||
}
|
}
|
||||||
@ -237,19 +237,19 @@ BINARY_GPU(Subtract)
|
|||||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
switch (op_) {
|
switch (op_) {
|
||||||
case BitwiseBinary::And:
|
case BitwiseBinary::And:
|
||||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
binary_op_gpu(inputs, out, name());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Or:
|
case BitwiseBinary::Or:
|
||||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
binary_op_gpu(inputs, out, name());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Xor:
|
case BitwiseBinary::Xor:
|
||||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
binary_op_gpu(inputs, out, name());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::LeftShift:
|
case BitwiseBinary::LeftShift:
|
||||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
binary_op_gpu(inputs, out, name());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::RightShift:
|
case BitwiseBinary::RightShift:
|
||||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
binary_op_gpu(inputs, out, name());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,25 +9,25 @@ namespace mlx::core {
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string& op,
|
const char* op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -212,9 +212,7 @@ inline void build_kernel(
|
|||||||
get_type_string(x.dtype()),
|
get_type_string(x.dtype()),
|
||||||
namer.get_name(x.inputs()[0]));
|
namer.get_name(x.inputs()[0]));
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream ss;
|
os += x.primitive().name();
|
||||||
x.primitive().print(ss);
|
|
||||||
os += ss.str();
|
|
||||||
os += "()(";
|
os += "()(";
|
||||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||||
|
@ -8,12 +8,6 @@ using namespace fmt::literals;
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string op_name(const array& arr) {
|
|
||||||
std::ostringstream op_t;
|
|
||||||
arr.primitive().print(op_t);
|
|
||||||
return op_t.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_arange_kernel(
|
MTL::ComputePipelineState* get_arange_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const char* op) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name, [&]() {
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
auto in_t = get_type_string(in_type);
|
auto in_t = get_type_string(in_type);
|
||||||
@ -58,10 +52,10 @@ MTL::ComputePipelineState* get_unary_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void append_binary_kernels(
|
void append_binary_kernels(
|
||||||
const std::string lib_name,
|
const std::string& lib_name,
|
||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op,
|
const char* op,
|
||||||
std::string& kernel_source) {
|
std::string& kernel_source) {
|
||||||
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
|
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
|
||||||
{"ss", "binary_ss"},
|
{"ss", "binary_ss"},
|
||||||
@ -112,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const char* op) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name, [&]() {
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
std::string kernel_source;
|
std::string kernel_source;
|
||||||
@ -129,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op) {
|
const char* op) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name, [&]() {
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
std::string kernel_source = metal::utils();
|
std::string kernel_source = metal::utils();
|
||||||
@ -144,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype type,
|
Dtype type,
|
||||||
const std::string op) {
|
const char* op) {
|
||||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
auto lib = d.get_library(lib_name, [&]() {
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
auto t_str = get_type_string(type);
|
auto t_str = get_type_string(type);
|
||||||
|
@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op);
|
const char* op);
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_binary_kernel(
|
MTL::ComputePipelineState* get_binary_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op);
|
const char* op);
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype in_type,
|
Dtype in_type,
|
||||||
Dtype out_type,
|
Dtype out_type,
|
||||||
const std::string op);
|
const char* op);
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_ternary_kernel(
|
MTL::ComputePipelineState* get_ternary_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype type,
|
Dtype type,
|
||||||
const std::string op);
|
const char* op);
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_copy_kernel(
|
MTL::ComputePipelineState* get_copy_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@ -257,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
|||||||
|
|
||||||
// Create a GPU kernel template definition for JIT compilation
|
// Create a GPU kernel template definition for JIT compilation
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
std::string
|
std::string get_template_definition(
|
||||||
get_template_definition(std::string name, std::string func, Args... args) {
|
std::string_view name,
|
||||||
|
std::string_view func,
|
||||||
|
Args... args) {
|
||||||
std::ostringstream s;
|
std::ostringstream s;
|
||||||
s << func << "<";
|
s << func << "<";
|
||||||
bool first = true;
|
bool first = true;
|
||||||
|
@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype,
|
Dtype,
|
||||||
Dtype,
|
Dtype,
|
||||||
const std::string) {
|
const char*) {
|
||||||
return d.get_kernel(kernel_name);
|
return d.get_kernel(kernel_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype,
|
Dtype,
|
||||||
Dtype,
|
Dtype,
|
||||||
const std::string) {
|
const char*) {
|
||||||
return d.get_kernel(kernel_name);
|
return d.get_kernel(kernel_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype,
|
Dtype,
|
||||||
Dtype,
|
Dtype,
|
||||||
const std::string) {
|
const char*) {
|
||||||
return d.get_kernel(kernel_name);
|
return d.get_kernel(kernel_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
Dtype,
|
Dtype,
|
||||||
const std::string) {
|
const char*) {
|
||||||
return d.get_kernel(kernel_name);
|
return d.get_kernel(kernel_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ namespace mlx::core {
|
|||||||
void ternary_op_gpu_inplace(
|
void ternary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() == 3);
|
assert(inputs.size() == 3);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
@ -128,7 +128,7 @@ void ternary_op_gpu_inplace(
|
|||||||
void ternary_op_gpu(
|
void ternary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
@ -141,13 +141,13 @@ void ternary_op_gpu(
|
|||||||
void ternary_op_gpu(
|
void ternary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op) {
|
const char* op) {
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
ternary_op_gpu(inputs, out, op, s);
|
ternary_op_gpu(inputs, out, op, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
ternary_op_gpu(inputs, out, get_primitive_string(this));
|
ternary_op_gpu(inputs, out, name());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -9,13 +9,13 @@ namespace mlx::core {
|
|||||||
void ternary_op_gpu(
|
void ternary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const char* op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
void ternary_op_gpu_inplace(
|
void ternary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const char* op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
#define UNARY_GPU(func) \
|
#define UNARY_GPU(func) \
|
||||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||||
unary_op_gpu(inputs, out, get_primitive_string(this)); \
|
unary_op_gpu(inputs, out, name()); \
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -16,7 +16,7 @@ namespace mlx::core {
|
|||||||
void unary_op_gpu_inplace(
|
void unary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
bool contig = in.flags().contiguous;
|
bool contig = in.flags().contiguous;
|
||||||
@ -98,7 +98,7 @@ void unary_op_gpu_inplace(
|
|||||||
void unary_op_gpu(
|
void unary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const char* op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
set_unary_output_data(inputs[0], out);
|
set_unary_output_data(inputs[0], out);
|
||||||
unary_op_gpu_inplace(inputs, out, op, s);
|
unary_op_gpu_inplace(inputs, out, op, s);
|
||||||
@ -107,7 +107,7 @@ void unary_op_gpu(
|
|||||||
void unary_op_gpu(
|
void unary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op) {
|
const char* op) {
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
unary_op_gpu(inputs, out, op, s);
|
unary_op_gpu(inputs, out, op, s);
|
||||||
}
|
}
|
||||||
@ -146,13 +146,13 @@ UNARY_GPU(Tanh)
|
|||||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
switch (base_) {
|
switch (base_) {
|
||||||
case Base::e:
|
case Base::e:
|
||||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
unary_op_gpu(inputs, out, name());
|
||||||
break;
|
break;
|
||||||
case Base::two:
|
case Base::two:
|
||||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
unary_op_gpu(inputs, out, name());
|
||||||
break;
|
break;
|
||||||
case Base::ten:
|
case Base::ten:
|
||||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
unary_op_gpu(inputs, out, name());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (issubdtype(in.dtype(), inexact)) {
|
if (issubdtype(in.dtype(), inexact)) {
|
||||||
unary_op_gpu(inputs, out, get_primitive_string(this));
|
unary_op_gpu(inputs, out, name());
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
|
@ -9,13 +9,13 @@ namespace mlx::core {
|
|||||||
void unary_op_gpu(
|
void unary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const char* op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
void unary_op_gpu_inplace(
|
void unary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const char* op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -40,7 +40,7 @@ inline void debug_set_primitive_buffer_label(
|
|||||||
if (auto cbuf_label = command_buffer->label(); cbuf_label) {
|
if (auto cbuf_label = command_buffer->label(); cbuf_label) {
|
||||||
label << cbuf_label->utf8String();
|
label << cbuf_label->utf8String();
|
||||||
}
|
}
|
||||||
primitive.print(label);
|
label << primitive.name();
|
||||||
command_buffer->setLabel(make_string(label));
|
command_buffer->setLabel(make_string(label));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,7 @@ Compiled::Compiled(
|
|||||||
// name and type of output
|
// name and type of output
|
||||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||||
// computation performed
|
// computation performed
|
||||||
a.primitive().print(os);
|
os << a.primitive().name();
|
||||||
// name of inputs to the function
|
// name of inputs to the function
|
||||||
for (auto& inp : a.inputs()) {
|
for (auto& inp : a.inputs()) {
|
||||||
os << namer.get_name(inp);
|
os << namer.get_name(inp);
|
||||||
@ -170,11 +170,16 @@ bool Compiled::is_equivalent(const Primitive& other) const {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compiled::print(std::ostream& os) {
|
const char* Compiled::name() const {
|
||||||
os << "Compiled";
|
if (name_.empty()) {
|
||||||
for (auto& a : tape_) {
|
std::ostringstream os;
|
||||||
a.primitive().print(os);
|
os << "Compiled";
|
||||||
|
for (auto& a : tape_) {
|
||||||
|
os << a.primitive().name();
|
||||||
|
}
|
||||||
|
name_ = os.str();
|
||||||
}
|
}
|
||||||
|
return name_.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
|
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
void print(std::ostream& os) override {
|
const char* name() const override {
|
||||||
switch (reduce_type_) {
|
switch (reduce_type_) {
|
||||||
case And:
|
case And:
|
||||||
os << "And";
|
return "And AllReduce";
|
||||||
case Or:
|
case Or:
|
||||||
os << "And";
|
return "Or AllReduce";
|
||||||
break;
|
|
||||||
case Sum:
|
case Sum:
|
||||||
os << "Sum";
|
return "Sum AllReduce";
|
||||||
break;
|
|
||||||
case Prod:
|
case Prod:
|
||||||
os << "Prod";
|
return "Prod AllReduce";
|
||||||
break;
|
|
||||||
case Min:
|
case Min:
|
||||||
os << "Min";
|
return "Min AllReduce";
|
||||||
break;
|
|
||||||
case Max:
|
case Max:
|
||||||
os << "Max";
|
return "Max AllReduce";
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
os << " AllReduce";
|
return "<unknwon AllReduce>";
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -94,7 +89,7 @@ class AllGather : public DistPrimitive {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
DEFINE_PRINT(AllGather);
|
DEFINE_NAME(AllGather);
|
||||||
};
|
};
|
||||||
|
|
||||||
class Send : public DistPrimitive {
|
class Send : public DistPrimitive {
|
||||||
@ -110,7 +105,7 @@ class Send : public DistPrimitive {
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
DEFINE_PRINT(Send);
|
DEFINE_NAME(Send);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int dst_;
|
int dst_;
|
||||||
@ -126,7 +121,7 @@ class Recv : public DistPrimitive {
|
|||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
|
|
||||||
DEFINE_PRINT(Recv);
|
DEFINE_NAME(Recv);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int src_;
|
int src_;
|
||||||
|
@ -354,9 +354,7 @@ struct PrimitiveFactory {
|
|||||||
|
|
||||||
void save(Writer& os, const std::shared_ptr<Primitive>& p) {
|
void save(Writer& os, const std::shared_ptr<Primitive>& p) {
|
||||||
serialize(os, p->stream());
|
serialize(os, p->stream());
|
||||||
std::ostringstream pout;
|
std::string name = p->name();
|
||||||
p->print(pout);
|
|
||||||
auto name = pout.str();
|
|
||||||
name = name.substr(0, name.find(' '));
|
name = name.substr(0, name.find(' '));
|
||||||
if (auto it = name_remap.find(name); it != name_remap.end()) {
|
if (auto it = name_remap.find(name); it != name_remap.end()) {
|
||||||
name = it->second;
|
name = it->second;
|
||||||
|
@ -58,7 +58,7 @@ class RMSNorm : public Custom {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
DEFINE_PRINT(RMSNorm)
|
DEFINE_NAME(RMSNorm)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ class RMSNormVJP : public Custom {
|
|||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
|
|
||||||
DEFINE_PRINT(RMSNormVJP)
|
DEFINE_NAME(RMSNormVJP)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_pair(nullptr, eps_);
|
return std::make_pair(nullptr, eps_);
|
||||||
@ -118,7 +118,7 @@ class LayerNorm : public Custom {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
DEFINE_PRINT(LayerNorm)
|
DEFINE_NAME(LayerNorm)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
auto state() const {
|
auto state() const {
|
||||||
@ -144,7 +144,7 @@ class LayerNormVJP : public Custom {
|
|||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
|
|
||||||
DEFINE_PRINT(LayerNormVJP)
|
DEFINE_NAME(LayerNormVJP)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_pair(nullptr, eps_);
|
return std::make_pair(nullptr, eps_);
|
||||||
@ -186,7 +186,7 @@ class RoPE : public Custom {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
DEFINE_PRINT(RoPE)
|
DEFINE_NAME(RoPE)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
auto state() const {
|
auto state() const {
|
||||||
@ -233,7 +233,7 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
void eval_gpu(const std::vector<array>& inputs, array& out);
|
void eval_gpu(const std::vector<array>& inputs, array& out);
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
DEFINE_PRINT(ScaledDotProductAttention);
|
DEFINE_NAME(ScaledDotProductAttention);
|
||||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_tuple(nullptr, scale_, do_causal_);
|
return std::make_tuple(nullptr, scale_, do_causal_);
|
||||||
@ -263,7 +263,7 @@ class AffineQuantize : public Custom {
|
|||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
|
|
||||||
DEFINE_PRINT(AffineQuantize);
|
DEFINE_NAME(AffineQuantize);
|
||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
@ -311,7 +311,7 @@ class CustomKernel : public Primitive {
|
|||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
|
|
||||||
DEFINE_PRINT(CustomKernel);
|
DEFINE_NAME(CustomKernel);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string source_;
|
std::string source_;
|
||||||
|
@ -93,7 +93,7 @@ void print_graph(
|
|||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
for (auto& arr : tape) {
|
for (auto& arr : tape) {
|
||||||
arr.primitive().print(os);
|
os << arr.primitive().name();
|
||||||
os << " ";
|
os << " ";
|
||||||
print_arrs(arr.inputs());
|
print_arrs(arr.inputs());
|
||||||
os << " -> ";
|
os << " -> ";
|
||||||
@ -143,7 +143,7 @@ void export_to_dot(
|
|||||||
os << "{ ";
|
os << "{ ";
|
||||||
os << x.primitive_id();
|
os << x.primitive_id();
|
||||||
os << " [label =\"";
|
os << " [label =\"";
|
||||||
x.primitive().print(os);
|
os << x.primitive().name();
|
||||||
os << "\", shape=rectangle]";
|
os << "\", shape=rectangle]";
|
||||||
os << "; }" << std::endl;
|
os << "; }" << std::endl;
|
||||||
// Arrows to primitive's inputs
|
// Arrows to primitive's inputs
|
||||||
|
@ -500,7 +500,7 @@ array cross(
|
|||||||
void validate_eig(
|
void validate_eig(
|
||||||
const array& a,
|
const array& a,
|
||||||
const StreamOrDevice& stream,
|
const StreamOrDevice& stream,
|
||||||
const std::string fname) {
|
const std::string& fname) {
|
||||||
check_cpu_stream(stream, fname);
|
check_cpu_stream(stream, fname);
|
||||||
check_float_or_complex(a.dtype(), fname);
|
check_float_or_complex(a.dtype(), fname);
|
||||||
|
|
||||||
|
@ -181,7 +181,7 @@ std::vector<array> Primitive::jvp(
|
|||||||
const std::vector<int>&) {
|
const std::vector<int>&) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Primitive::jvp] Not implemented for ";
|
msg << "[Primitive::jvp] Not implemented for ";
|
||||||
print(msg);
|
msg << name();
|
||||||
msg << ".";
|
msg << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
@ -193,7 +193,7 @@ std::vector<array> Primitive::vjp(
|
|||||||
const std::vector<array>&) {
|
const std::vector<array>&) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Primitive::vjp] Not implemented for ";
|
msg << "[Primitive::vjp] Not implemented for ";
|
||||||
print(msg);
|
msg << name();
|
||||||
msg << ".";
|
msg << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
@ -203,7 +203,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
|||||||
const std::vector<int>&) {
|
const std::vector<int>&) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Primitive::vmap] Not implemented for ";
|
msg << "[Primitive::vmap] Not implemented for ";
|
||||||
print(msg);
|
msg << name();
|
||||||
msg << ".";
|
msg << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
@ -211,7 +211,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
|||||||
std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {
|
std::vector<Shape> Primitive::output_shapes(const std::vector<array>&) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Primitive::output_shapes] ";
|
msg << "[Primitive::output_shapes] ";
|
||||||
this->print(msg);
|
msg << name();
|
||||||
msg << " cannot infer output shapes.";
|
msg << " cannot infer output shapes.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
@ -743,26 +743,6 @@ bool BitwiseBinary::is_equivalent(const Primitive& other) const {
|
|||||||
return op_ == a_other.op_;
|
return op_ == a_other.op_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BitwiseBinary::print(std::ostream& os) {
|
|
||||||
switch (op_) {
|
|
||||||
case BitwiseBinary::And:
|
|
||||||
os << "BitwiseAnd";
|
|
||||||
break;
|
|
||||||
case BitwiseBinary::Or:
|
|
||||||
os << "BitwiseOr";
|
|
||||||
break;
|
|
||||||
case BitwiseBinary::Xor:
|
|
||||||
os << "BitwiseXor";
|
|
||||||
break;
|
|
||||||
case BitwiseBinary::LeftShift:
|
|
||||||
os << "LeftShift";
|
|
||||||
break;
|
|
||||||
case BitwiseBinary::RightShift:
|
|
||||||
os << "RightShift";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
|
std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
@ -5375,8 +5355,13 @@ std::pair<std::vector<array>, std::vector<int>> View::vmap(
|
|||||||
return {{view(inputs[0], dtype_, stream())}, axes};
|
return {{view(inputs[0], dtype_, stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
void View::print(std::ostream& os) {
|
const char* View::name() const {
|
||||||
os << "View " << dtype_;
|
if (name_.empty()) {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "View " << dtype_;
|
||||||
|
name_ = os.str();
|
||||||
|
}
|
||||||
|
return name_.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool View::is_equivalent(const Primitive& other) const {
|
bool View::is_equivalent(const Primitive& other) const {
|
||||||
|
332
mlx/primitives.h
332
mlx/primitives.h
File diff suppressed because it is too large
Load Diff
@ -33,7 +33,7 @@ class Synchronizer : public Primitive {
|
|||||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}
|
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}
|
||||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}
|
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}
|
||||||
|
|
||||||
DEFINE_PRINT(Synchronize);
|
DEFINE_NAME(Synchronize);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize the static tracing members from transforms_impl.h
|
// Initialize the static tracing members from transforms_impl.h
|
||||||
|
@ -514,7 +514,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"eigh",
|
"eigh",
|
||||||
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
|
[](const mx::array& a, const std::string& UPLO, mx::StreamOrDevice s) {
|
||||||
auto result = mx::linalg::eigh(a, UPLO, s);
|
auto result = mx::linalg::eigh(a, UPLO, s);
|
||||||
return nb::make_tuple(result.first, result.second);
|
return nb::make_tuple(result.first, result.second);
|
||||||
},
|
},
|
||||||
|
@ -14,7 +14,7 @@ namespace mx = mlx::core;
|
|||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
bool DEPRECATE(const std::string& old_fn, const std::string new_fn) {
|
bool DEPRECATE(const char* old_fn, const char* new_fn) {
|
||||||
std::cerr << old_fn << " is deprecated and will be removed in a future "
|
std::cerr << old_fn << " is deprecated and will be removed in a future "
|
||||||
<< "version. Use " << new_fn << " instead." << std::endl;
|
<< "version. Use " << new_fn << " instead." << std::endl;
|
||||||
return true;
|
return true;
|
||||||
|
@ -3076,7 +3076,7 @@ void init_ops(nb::module_& m) {
|
|||||||
std::tuple<int>,
|
std::tuple<int>,
|
||||||
std::pair<int, int>,
|
std::pair<int, int>,
|
||||||
std::vector<std::pair<int, int>>>& pad_width,
|
std::vector<std::pair<int, int>>>& pad_width,
|
||||||
const std::string mode,
|
const std::string& mode,
|
||||||
const ScalarOrArray& constant_value,
|
const ScalarOrArray& constant_value,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (auto pv = std::get_if<int>(&pad_width); pv) {
|
if (auto pv = std::get_if<int>(&pad_width); pv) {
|
||||||
|
Loading…
Reference in New Issue
Block a user