mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 10:51:21 +08:00
parent
2b8ace6a03
commit
3f86399922
@ -80,6 +80,7 @@ Operations
|
||||
greater_equal
|
||||
hadamard_transform
|
||||
identity
|
||||
imag
|
||||
inner
|
||||
isfinite
|
||||
isclose
|
||||
@ -125,6 +126,7 @@ Operations
|
||||
quantize
|
||||
quantized_matmul
|
||||
radians
|
||||
real
|
||||
reciprocal
|
||||
remainder
|
||||
repeat
|
||||
|
@ -295,6 +295,13 @@ struct Floor {
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::imag(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
@ -337,6 +344,13 @@ struct Negative {
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::real(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@ -273,6 +273,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
}
|
||||
|
||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@ -398,6 +402,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
@ -24,26 +24,26 @@ void set_unary_output_data(const array& in, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) {
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
out[i] = op(*a);
|
||||
a += stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
if (a.flags().contiguous) {
|
||||
set_unary_output_data(a, out);
|
||||
T* dst = out.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
dst[i] = op(a_ptr[i]);
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
T* dst = out.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
|
||||
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
|
||||
if (a.ndim() <= 1) {
|
||||
|
@ -40,18 +40,21 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
|
||||
kernel_source << get_template_definition(
|
||||
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||
"v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
||||
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
|
||||
|
@ -1,27 +1,27 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void unary_v(
|
||||
device const T* in,
|
||||
device T* out,
|
||||
device U* out,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] = Op()(in[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void unary_v2(
|
||||
device const T* in,
|
||||
device T* out,
|
||||
device U* out,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
out[offset] = Op()(in[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename Op, int N = 1>
|
||||
template <typename T, typename U, typename Op, int N = 1>
|
||||
[[kernel]] void unary_g(
|
||||
device const T* in,
|
||||
device T* out,
|
||||
device U* out,
|
||||
constant const int* in_shape,
|
||||
constant const size_t* in_strides,
|
||||
device const int& ndim,
|
||||
|
@ -5,26 +5,30 @@
|
||||
#include "mlx/backend/metal/kernels/unary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
#define instantiate_unary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
|
||||
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4)
|
||||
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
|
||||
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
|
||||
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
|
||||
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
|
||||
#define instantiate_unary_all_same(op, tname, type) \
|
||||
instantiate_unary_all(op, tname, tname, type, type)
|
||||
|
||||
#define instantiate_unary_float(op) \
|
||||
instantiate_unary_all(op, float16, half) \
|
||||
instantiate_unary_all(op, float32, float) \
|
||||
instantiate_unary_all(op, bfloat16, bfloat16_t)
|
||||
instantiate_unary_all_same(op, float16, half) \
|
||||
instantiate_unary_all_same(op, float32, float) \
|
||||
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_unary_types(op) \
|
||||
instantiate_unary_all(op, bool_, bool) \
|
||||
instantiate_unary_all(op, uint8, uint8_t) \
|
||||
instantiate_unary_all(op, uint16, uint16_t) \
|
||||
instantiate_unary_all(op, uint32, uint32_t) \
|
||||
instantiate_unary_all(op, uint64, uint64_t) \
|
||||
instantiate_unary_all(op, int8, int8_t) \
|
||||
instantiate_unary_all(op, int16, int16_t) \
|
||||
instantiate_unary_all(op, int32, int32_t) \
|
||||
instantiate_unary_all(op, int64, int64_t) \
|
||||
instantiate_unary_all_same(op, bool_, bool) \
|
||||
instantiate_unary_all_same(op, uint8, uint8_t) \
|
||||
instantiate_unary_all_same(op, uint16, uint16_t) \
|
||||
instantiate_unary_all_same(op, uint32, uint32_t) \
|
||||
instantiate_unary_all_same(op, uint64, uint64_t) \
|
||||
instantiate_unary_all_same(op, int8, int8_t) \
|
||||
instantiate_unary_all_same(op, int16, int16_t) \
|
||||
instantiate_unary_all_same(op, int32, int32_t) \
|
||||
instantiate_unary_all_same(op, int64, int64_t) \
|
||||
instantiate_unary_float(op)
|
||||
|
||||
instantiate_unary_types(Abs)
|
||||
@ -58,17 +62,19 @@ instantiate_unary_float(Tan)
|
||||
instantiate_unary_float(Tanh)
|
||||
instantiate_unary_float(Round)
|
||||
|
||||
instantiate_unary_all(Abs, complex64, complex64_t)
|
||||
instantiate_unary_all(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all(Tan, complex64, complex64_t)
|
||||
instantiate_unary_all(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_all(Round, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Round, complex64, complex64_t)
|
||||
instantiate_unary_all(Real, complex64, float32, complex64_t, float)
|
||||
instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
|
||||
|
||||
instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on
|
||||
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on
|
||||
|
@ -238,6 +238,13 @@ struct Floor {
|
||||
};
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x.imag;
|
||||
};
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
@ -280,6 +287,13 @@ struct Negative {
|
||||
};
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x.real;
|
||||
};
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@ -16,6 +16,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype,
|
||||
Dtype,
|
||||
const std::string) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
@ -44,8 +44,8 @@ void unary_op_gpu_inplace(
|
||||
} else {
|
||||
kernel_name = (work_per_thread == 4 ? "gn4" : "g");
|
||||
}
|
||||
kernel_name += "_" + op + type_to_name(out);
|
||||
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
|
||||
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
|
||||
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
|
||||
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
@ -124,11 +124,13 @@ UNARY_GPU(Erf)
|
||||
UNARY_GPU(ErfInv)
|
||||
UNARY_GPU(Exp)
|
||||
UNARY_GPU(Expm1)
|
||||
UNARY_GPU(Imag)
|
||||
UNARY_GPU(Log1p)
|
||||
UNARY_GPU(LogicalNot)
|
||||
UNARY_GPU(Floor)
|
||||
UNARY_GPU(Ceil)
|
||||
UNARY_GPU(Negative)
|
||||
UNARY_GPU(Real)
|
||||
UNARY_GPU(Sigmoid)
|
||||
UNARY_GPU(Sign)
|
||||
UNARY_GPU(Sin)
|
||||
|
@ -62,6 +62,7 @@ NO_CPU(GatherQMM)
|
||||
NO_CPU(Greater)
|
||||
NO_CPU(GreaterEqual)
|
||||
NO_CPU(Hadamard)
|
||||
NO_CPU(Imag)
|
||||
NO_CPU(Less)
|
||||
NO_CPU(LessEqual)
|
||||
NO_CPU(Load)
|
||||
@ -83,6 +84,7 @@ NO_CPU(Power)
|
||||
NO_CPU_MULTI(QRF)
|
||||
NO_CPU(QuantizedMatmul)
|
||||
NO_CPU(RandomBits)
|
||||
NO_CPU(Real)
|
||||
NO_CPU(Reduce)
|
||||
NO_CPU(Reshape)
|
||||
NO_CPU(Round)
|
||||
|
@ -64,6 +64,7 @@ NO_GPU(GatherQMM)
|
||||
NO_GPU(Greater)
|
||||
NO_GPU(GreaterEqual)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Imag)
|
||||
NO_GPU(Less)
|
||||
NO_GPU(LessEqual)
|
||||
NO_GPU(Load)
|
||||
@ -85,6 +86,7 @@ NO_GPU(Power)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Real)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Reshape)
|
||||
NO_GPU(Round)
|
||||
|
@ -33,7 +33,8 @@ bool is_unary(const Primitive& p) {
|
||||
typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||
|
||||
typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
|
||||
typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
|
||||
typeid(p) == typeid(Expm1));
|
||||
typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) ||
|
||||
typeid(p) == typeid(Imag));
|
||||
}
|
||||
|
||||
bool is_binary(const Primitive& p) {
|
||||
|
14
mlx/ops.cpp
14
mlx/ops.cpp
@ -4657,4 +4657,18 @@ array roll(
|
||||
return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s);
|
||||
}
|
||||
|
||||
array real(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (!issubdtype(a.dtype(), complexfloating)) {
|
||||
return a;
|
||||
}
|
||||
return array(a.shape(), float32, std::make_shared<Real>(to_stream(s)), {a});
|
||||
}
|
||||
|
||||
array imag(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (!issubdtype(a.dtype(), complexfloating)) {
|
||||
return zeros_like(a);
|
||||
}
|
||||
return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1480,6 +1480,12 @@ array roll(
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/* The real part of a complex array. */
|
||||
array real(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/* The imaginary part of a complex array. */
|
||||
array imag(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** @} */
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1797,6 +1797,36 @@ std::vector<array> GreaterEqual::jvp(
|
||||
return {zeros(shape, bool_, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Imag::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {multiply(
|
||||
array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
|
||||
cotangents[0],
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Imag::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {imag(tangents[0], stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Imag::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {{imag(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Less::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@ -2633,6 +2663,33 @@ bool RandomBits::is_equivalent(const Primitive& other) const {
|
||||
return shape_ == r_other.shape_;
|
||||
}
|
||||
|
||||
std::vector<array> Real::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {astype(cotangents[0], primals[0].dtype(), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Real::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {real(tangents[0], stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Real::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {{real(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
@ -1106,6 +1106,20 @@ class Hadamard : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Imag : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Imag(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Imag)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
};
|
||||
|
||||
class Less : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Less(Stream stream) : UnaryPrimitive(stream) {}
|
||||
@ -1561,6 +1575,20 @@ class RandomBits : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Real : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Real(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Real)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
};
|
||||
|
||||
class Reshape : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Reshape(Stream stream, const std::vector<int>& shape)
|
||||
|
@ -4842,4 +4842,42 @@ void init_ops(nb::module_& m) {
|
||||
axis (int or tuple(int), optional): The axis or axes along which to
|
||||
roll the elements.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"real",
|
||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||
return mlx::core::real(to_array(a), s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def real(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Returns the real part of a complex array.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The real part of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"imag",
|
||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||
return mlx::core::imag(to_array(a), s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def imag(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Returns the imaginary part of a complex array.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The imaginary part of ``a``.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -590,6 +590,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(out1[0], out2[0]))
|
||||
self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))
|
||||
|
||||
def test_complex_vjps(self):
|
||||
def fun(x):
|
||||
return (2.0 * mx.real(x)).sum()
|
||||
|
||||
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
|
||||
dfdx = mx.grad(fun)(x)
|
||||
self.assertTrue(mx.allclose(dfdx, 2 * mx.ones_like(x)))
|
||||
|
||||
def fun(x):
|
||||
return (2.0 * mx.imag(x)).sum()
|
||||
|
||||
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
|
||||
dfdx = mx.grad(fun)(x)
|
||||
self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -2680,6 +2680,21 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
y2 = mx.roll(x, s, a)
|
||||
self.assertTrue(mx.array_equal(y1, y2).item())
|
||||
|
||||
def test_real_imag(self):
|
||||
x = mx.random.uniform(shape=(4, 4))
|
||||
out = mx.real(x)
|
||||
self.assertTrue(mx.array_equal(x, out))
|
||||
|
||||
out = mx.imag(x)
|
||||
self.assertTrue(mx.array_equal(mx.zeros_like(x), out))
|
||||
|
||||
y = mx.random.uniform(shape=(4, 4))
|
||||
z = x + 1j * y
|
||||
self.assertEqual(mx.real(z).dtype, mx.float32)
|
||||
self.assertTrue(mx.array_equal(mx.real(z), x))
|
||||
self.assertEqual(mx.imag(z).dtype, mx.float32)
|
||||
self.assertTrue(mx.array_equal(mx.imag(z), y))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user