mirror of https://github.com/ml-explore/mlx.git
synced 2025-07-27 20:07:59 +08:00
parent 2b8ace6a03
commit 3f86399922

@@ -80,6 +80,7 @@ Operations
    greater_equal
    hadamard_transform
    identity
+   imag
    inner
    isfinite
    isclose

@@ -125,6 +126,7 @@ Operations
    quantize
    quantized_matmul
    radians
+   real
    reciprocal
    remainder
    repeat

@@ -295,6 +295,13 @@ struct Floor {
   }
 };

+struct Imag {
+  template <typename T>
+  T operator()(T x) {
+    return std::imag(x);
+  }
+};
+
 struct Log {
   template <typename T>
   T operator()(T x) {

@@ -337,6 +344,13 @@ struct Negative {
   }
 };

+struct Real {
+  template <typename T>
+  T operator()(T x) {
+    return std::real(x);
+  }
+};
+
 struct Round {
   template <typename T>
   T operator()(T x) {

@@ -273,6 +273,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
   copy(in, out, ctype);
 }

+void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
+  unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
+}
+
 void Log::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];

@@ -398,6 +402,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
   }
 }

+void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
+  unary_op<complex64_t, float>(inputs[0], out, detail::Real());
+}
+
 void Reshape::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];

@@ -24,26 +24,26 @@ void set_unary_output_data(const array& in, array& out) {
   }
 }

-template <typename T, typename Op>
-void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) {
+template <typename T, typename U = T, typename Op>
+void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
   for (size_t i = 0; i < shape; i += 1) {
     out[i] = op(*a);
     a += stride;
   }
 }

-template <typename T, typename Op>
+template <typename T, typename U = T, typename Op>
 void unary_op(const array& a, array& out, Op op) {
   const T* a_ptr = a.data<T>();
   if (a.flags().contiguous) {
     set_unary_output_data(a, out);
-    T* dst = out.data<T>();
+    U* dst = out.data<U>();
     for (size_t i = 0; i < a.data_size(); ++i) {
       dst[i] = op(a_ptr[i]);
     }
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    T* dst = out.data<T>();
+    U* dst = out.data<U>();
     size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
     size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
     if (a.ndim() <= 1) {

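The extra U template parameter (defaulting to T) lets a unary op write a different element type than it reads. The Imag::eval_cpu and Real::eval_cpu hunks above rely on this, instantiating unary_op<complex64_t, float> so that a complex64 input produces a float32 output.
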
@@ -40,18 +40,21 @@ MTL::ComputePipelineState* get_arange_kernel(
 MTL::ComputePipelineState* get_unary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    Dtype in_type,
     Dtype out_type,
     const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
+    auto in_t = get_type_string(in_type);
+    auto out_t = get_type_string(out_type);
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
     kernel_source << get_template_definition(
-        "v_" + lib_name, "unary_v", get_type_string(out_type), op);
+        "v_" + lib_name, "unary_v", in_t, out_t, op);
     kernel_source << get_template_definition(
-        "v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
+        "v2_" + lib_name, "unary_v2", in_t, out_t, op);
     kernel_source << get_template_definition(
-        "gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
+        "gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
     return kernel_source.str();
   });
   return d.get_kernel(kernel_name, lib);

@@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
 MTL::ComputePipelineState* get_unary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    Dtype in_type,
     Dtype out_type,
     const std::string op);

@@ -1,27 +1,27 @@
 // Copyright © 2024 Apple Inc.

-template <typename T, typename Op>
+template <typename T, typename U, typename Op>
 [[kernel]] void unary_v(
     device const T* in,
-    device T* out,
+    device U* out,
     uint index [[thread_position_in_grid]]) {
   out[index] = Op()(in[index]);
 }

-template <typename T, typename Op>
+template <typename T, typename U, typename Op>
 [[kernel]] void unary_v2(
     device const T* in,
-    device T* out,
+    device U* out,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   size_t offset = index.x + grid_dim.x * size_t(index.y);
   out[offset] = Op()(in[offset]);
 }

-template <typename T, typename Op, int N = 1>
+template <typename T, typename U, typename Op, int N = 1>
 [[kernel]] void unary_g(
     device const T* in,
-    device T* out,
+    device U* out,
     constant const int* in_shape,
     constant const size_t* in_strides,
     device const int& ndim,

@@ -5,26 +5,30 @@
 #include "mlx/backend/metal/kernels/unary_ops.h"
 #include "mlx/backend/metal/kernels/unary.h"

-#define instantiate_unary_all(op, tname, type) \
-  instantiate_kernel("v_" #op #tname, unary_v, type, op) \
-  instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
-  instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4)
+#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
+  instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
+  instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
+  instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
+
+#define instantiate_unary_all_same(op, tname, type) \
+  instantiate_unary_all(op, tname, tname, type, type)

 #define instantiate_unary_float(op) \
-  instantiate_unary_all(op, float16, half) \
-  instantiate_unary_all(op, float32, float) \
-  instantiate_unary_all(op, bfloat16, bfloat16_t)
+  instantiate_unary_all_same(op, float16, half) \
+  instantiate_unary_all_same(op, float32, float) \
+  instantiate_unary_all_same(op, bfloat16, bfloat16_t)

 #define instantiate_unary_types(op) \
-  instantiate_unary_all(op, bool_, bool) \
-  instantiate_unary_all(op, uint8, uint8_t) \
-  instantiate_unary_all(op, uint16, uint16_t) \
-  instantiate_unary_all(op, uint32, uint32_t) \
-  instantiate_unary_all(op, uint64, uint64_t) \
-  instantiate_unary_all(op, int8, int8_t) \
-  instantiate_unary_all(op, int16, int16_t) \
-  instantiate_unary_all(op, int32, int32_t) \
-  instantiate_unary_all(op, int64, int64_t) \
+  instantiate_unary_all_same(op, bool_, bool) \
+  instantiate_unary_all_same(op, uint8, uint8_t) \
+  instantiate_unary_all_same(op, uint16, uint16_t) \
+  instantiate_unary_all_same(op, uint32, uint32_t) \
+  instantiate_unary_all_same(op, uint64, uint64_t) \
+  instantiate_unary_all_same(op, int8, int8_t) \
+  instantiate_unary_all_same(op, int16, int16_t) \
+  instantiate_unary_all_same(op, int32, int32_t) \
+  instantiate_unary_all_same(op, int64, int64_t) \
   instantiate_unary_float(op)

 instantiate_unary_types(Abs)

@@ -58,17 +62,19 @@ instantiate_unary_float(Tan)
 instantiate_unary_float(Tanh)
 instantiate_unary_float(Round)

-instantiate_unary_all(Abs, complex64, complex64_t)
-instantiate_unary_all(Conjugate, complex64, complex64_t)
-instantiate_unary_all(Cos, complex64, complex64_t)
-instantiate_unary_all(Cosh, complex64, complex64_t)
-instantiate_unary_all(Exp, complex64, complex64_t)
-instantiate_unary_all(Negative, complex64, complex64_t)
-instantiate_unary_all(Sign, complex64, complex64_t)
-instantiate_unary_all(Sin, complex64, complex64_t)
-instantiate_unary_all(Sinh, complex64, complex64_t)
-instantiate_unary_all(Tan, complex64, complex64_t)
-instantiate_unary_all(Tanh, complex64, complex64_t)
-instantiate_unary_all(Round, complex64, complex64_t)
+instantiate_unary_all_same(Abs, complex64, complex64_t)
+instantiate_unary_all_same(Conjugate, complex64, complex64_t)
+instantiate_unary_all_same(Cos, complex64, complex64_t)
+instantiate_unary_all_same(Cosh, complex64, complex64_t)
+instantiate_unary_all_same(Exp, complex64, complex64_t)
+instantiate_unary_all_same(Negative, complex64, complex64_t)
+instantiate_unary_all_same(Sign, complex64, complex64_t)
+instantiate_unary_all_same(Sin, complex64, complex64_t)
+instantiate_unary_all_same(Sinh, complex64, complex64_t)
+instantiate_unary_all_same(Tan, complex64, complex64_t)
+instantiate_unary_all_same(Tanh, complex64, complex64_t)
+instantiate_unary_all_same(Round, complex64, complex64_t)
+instantiate_unary_all(Real, complex64, float32, complex64_t, float)
+instantiate_unary_all(Imag, complex64, float32, complex64_t, float)

-instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on
+instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on

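instantiate_unary_all now takes separate input and output type names, and the new instantiate_unary_all_same wrapper preserves the old single-type behavior for every existing op; Real and Imag are the only kernels instantiated with distinct types (complex64_t in, float out).
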
@@ -238,6 +238,13 @@ struct Floor {
   };
 };

+struct Imag {
+  template <typename T>
+  T operator()(T x) {
+    return x.imag;
+  };
+};
+
 struct Log {
   template <typename T>
   T operator()(T x) {

@@ -280,6 +287,13 @@ struct Negative {
   };
 };

+struct Real {
+  template <typename T>
+  T operator()(T x) {
+    return x.real;
+  };
+};
+
 struct Round {
   template <typename T>
   T operator()(T x) {

@@ -16,6 +16,7 @@ MTL::ComputePipelineState* get_unary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     Dtype,
+    Dtype,
     const std::string) {
   return d.get_kernel(kernel_name);
 }

@@ -44,8 +44,8 @@ void unary_op_gpu_inplace(
   } else {
     kernel_name = (work_per_thread == 4 ? "gn4" : "g");
   }
-  kernel_name += "_" + op + type_to_name(out);
-  auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
+  kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
+  auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);

   MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
                                : MTL::Size(nthreads, 1, 1);

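Since the kernels are now templated on both types, the runtime kernel name has to encode both as well. Assuming type_to_name yields "complex64" and "float32" for those dtypes, the Real kernel would be looked up under a name like v_Realcomplex64float32, matching the "v_" #op #in_tname #out_tname instantiation above.
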
@@ -124,11 +124,13 @@ UNARY_GPU(Erf)
 UNARY_GPU(ErfInv)
 UNARY_GPU(Exp)
 UNARY_GPU(Expm1)
+UNARY_GPU(Imag)
 UNARY_GPU(Log1p)
 UNARY_GPU(LogicalNot)
 UNARY_GPU(Floor)
 UNARY_GPU(Ceil)
 UNARY_GPU(Negative)
+UNARY_GPU(Real)
 UNARY_GPU(Sigmoid)
 UNARY_GPU(Sign)
 UNARY_GPU(Sin)

@@ -62,6 +62,7 @@ NO_CPU(GatherQMM)
 NO_CPU(Greater)
 NO_CPU(GreaterEqual)
 NO_CPU(Hadamard)
+NO_CPU(Imag)
 NO_CPU(Less)
 NO_CPU(LessEqual)
 NO_CPU(Load)

@@ -83,6 +84,7 @@ NO_CPU(Power)
 NO_CPU_MULTI(QRF)
 NO_CPU(QuantizedMatmul)
 NO_CPU(RandomBits)
+NO_CPU(Real)
 NO_CPU(Reduce)
 NO_CPU(Reshape)
 NO_CPU(Round)

@@ -64,6 +64,7 @@ NO_GPU(GatherQMM)
 NO_GPU(Greater)
 NO_GPU(GreaterEqual)
 NO_GPU(Hadamard)
+NO_GPU(Imag)
 NO_GPU(Less)
 NO_GPU(LessEqual)
 NO_GPU(Load)

@@ -85,6 +86,7 @@ NO_GPU(Power)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(RandomBits)
+NO_GPU(Real)
 NO_GPU(Reduce)
 NO_GPU(Reshape)
 NO_GPU(Round)

@@ -33,7 +33,8 @@ bool is_unary(const Primitive& p) {
       typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||
       typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
       typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
-      typeid(p) == typeid(Expm1));
+      typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) ||
+      typeid(p) == typeid(Imag));
 }

 bool is_binary(const Primitive& p) {

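Listing Real and Imag in is_unary lets the compiler treat them like the other elementwise unary primitives here, presumably so they can take part in the same fusion of elementwise chains.
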
mlx/ops.cpp (14 changed lines)

@@ -4657,4 +4657,18 @@ array roll(
   return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s);
 }

+array real(const array& a, StreamOrDevice s /* = {} */) {
+  if (!issubdtype(a.dtype(), complexfloating)) {
+    return a;
+  }
+  return array(a.shape(), float32, std::make_shared<Real>(to_stream(s)), {a});
+}
+
+array imag(const array& a, StreamOrDevice s /* = {} */) {
+  if (!issubdtype(a.dtype(), complexfloating)) {
+    return zeros_like(a);
+  }
+  return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
+}
+
 } // namespace mlx::core

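As written, both functions short-circuit for non-complex inputs: real returns the input unchanged and imag returns zeros with the input's shape and dtype. For complex inputs the result is always a float32 array.
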
@@ -1480,6 +1480,12 @@ array roll(
     const std::vector<int>& axes,
     StreamOrDevice s = {});

+/* The real part of a complex array. */
+array real(const array& a, StreamOrDevice s = {});
+
+/* The imaginary part of a complex array. */
+array imag(const array& a, StreamOrDevice s = {});
+
 /** @} */

 } // namespace mlx::core

@@ -1797,6 +1797,36 @@ std::vector<array> GreaterEqual::jvp(
   return {zeros(shape, bool_, stream())};
 }

+std::vector<array> Imag::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return {multiply(
+      array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
+      cotangents[0],
+      stream())};
+}
+
+std::vector<array> Imag::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return {imag(tangents[0], stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Imag::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  return {{imag(inputs[0], stream())}, axes};
+}
+
 std::pair<std::vector<array>, std::vector<int>> Less::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {

@@ -2633,6 +2663,33 @@ bool RandomBits::is_equivalent(const Primitive& other) const {
   return shape_ == r_other.shape_;
 }

+std::vector<array> Real::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return {astype(cotangents[0], primals[0].dtype(), stream())};
+}
+
+std::vector<array> Real::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 1);
+  assert(argnums.size() == 1);
+  return {real(tangents[0], stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Real::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+  return {{real(inputs[0], stream())}, axes};
+}
+
 std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {

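Read on a scalar x = a + bi with a real-valued loss f, these vjp rules are consistent with reporting the gradient as df/da - i*df/db: Real casts the cotangent back to the input dtype (a step along the real axis), while Imag scales it by -i. For example, f(x) = 2*imag(x) = 2b has df/da = 0 and df/db = 2, so the reported gradient is -2i, which is exactly what the autograd test further down asserts.
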
|
@ -1106,6 +1106,20 @@ class Hadamard : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Imag : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit Imag(Stream stream) : UnaryPrimitive(stream) {}
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(Imag)
|
||||||
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
};
|
||||||
|
|
||||||
class Less : public UnaryPrimitive {
|
class Less : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Less(Stream stream) : UnaryPrimitive(stream) {}
|
explicit Less(Stream stream) : UnaryPrimitive(stream) {}
|
||||||
@@ -1561,6 +1575,20 @@ class RandomBits : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };

+class Real : public UnaryPrimitive {
+ public:
+  explicit Real(Stream stream) : UnaryPrimitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Real)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+  DEFINE_INPUT_OUTPUT_SHAPE()
+};
+
 class Reshape : public UnaryPrimitive {
  public:
   explicit Reshape(Stream stream, const std::vector<int>& shape)

@@ -4842,4 +4842,42 @@ void init_ops(nb::module_& m) {
         axis (int or tuple(int), optional): The axis or axes along which to
           roll the elements.
       )pbdoc");
+  m.def(
+      "real",
+      [](const ScalarOrArray& a, StreamOrDevice s) {
+        return mlx::core::real(to_array(a), s);
+      },
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def real(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Returns the real part of a complex array.
+
+        Args:
+            a (array): Input array.
+
+        Returns:
+            array: The real part of ``a``.
+      )pbdoc");
+  m.def(
+      "imag",
+      [](const ScalarOrArray& a, StreamOrDevice s) {
+        return mlx::core::imag(to_array(a), s);
+      },
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def imag(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Returns the imaginary part of a complex array.
+
+        Args:
+            a (array): Input array.
+
+        Returns:
+            array: The imaginary part of ``a``.
+      )pbdoc");
 }

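A quick usage sketch of the new bindings; the commented values follow from the implementation above and are illustrative rather than copied output:

    import mlx.core as mx

    z = mx.array([1.0 + 2.0j, 3.0 - 4.0j])   # complex64 input
    mx.real(z)   # float32 array [1.0, 3.0]
    mx.imag(z)   # float32 array [2.0, -4.0]

    x = mx.array([1.0, 2.0])                  # non-complex input
    mx.real(x)   # returned unchanged
    mx.imag(x)   # zeros with x's shape and dtype
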
@@ -590,6 +590,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(out1[0], out2[0]))
         self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))

+    def test_complex_vjps(self):
+        def fun(x):
+            return (2.0 * mx.real(x)).sum()
+
+        x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
+        dfdx = mx.grad(fun)(x)
+        self.assertTrue(mx.allclose(dfdx, 2 * mx.ones_like(x)))
+
+        def fun(x):
+            return (2.0 * mx.imag(x)).sum()
+
+        x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
+        dfdx = mx.grad(fun)(x)
+        self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
+

 if __name__ == "__main__":
     unittest.main()

@@ -2680,6 +2680,21 @@ class TestOps(mlx_tests.MLXTestCase):
             y2 = mx.roll(x, s, a)
             self.assertTrue(mx.array_equal(y1, y2).item())

+    def test_real_imag(self):
+        x = mx.random.uniform(shape=(4, 4))
+        out = mx.real(x)
+        self.assertTrue(mx.array_equal(x, out))
+
+        out = mx.imag(x)
+        self.assertTrue(mx.array_equal(mx.zeros_like(x), out))
+
+        y = mx.random.uniform(shape=(4, 4))
+        z = x + 1j * y
+        self.assertEqual(mx.real(z).dtype, mx.float32)
+        self.assertTrue(mx.array_equal(mx.real(z), x))
+        self.assertEqual(mx.imag(z).dtype, mx.float32)
+        self.assertTrue(mx.array_equal(mx.imag(z), y))
+

 if __name__ == "__main__":
     unittest.main()