Real and Imag (#1490)

* real and imag

* fix

* fix
This commit is contained in:
Awni Hannun
2024-10-15 16:23:15 -07:00
committed by GitHub
parent 2b8ace6a03
commit 3f86399922
21 changed files with 275 additions and 46 deletions

View File

@@ -295,6 +295,13 @@ struct Floor {
}
};
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
};
struct Log {
template <typename T>
T operator()(T x) {
@@ -337,6 +344,13 @@ struct Negative {
}
};
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
};
struct Round {
template <typename T>
T operator()(T x) {

View File

@@ -273,6 +273,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype);
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -398,6 +402,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
}
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@@ -24,26 +24,26 @@ void set_unary_output_data(const array& in, array& out) {
}
}
template <typename T, typename Op>
void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) {
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
a += stride;
}
}
template <typename T, typename Op>
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
T* dst = out.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
T* dst = out.data<T>();
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {