Real and Imag (#1490)

* real and imag

* fix

* fix
This commit is contained in:
Awni Hannun 2024-10-15 16:23:15 -07:00 committed by GitHub
parent 2b8ace6a03
commit 3f86399922
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 275 additions and 46 deletions

View File

@ -80,6 +80,7 @@ Operations
greater_equal
hadamard_transform
identity
imag
inner
isfinite
isclose
@ -125,6 +126,7 @@ Operations
quantize
quantized_matmul
radians
real
reciprocal
remainder
repeat

View File

@ -295,6 +295,13 @@ struct Floor {
}
};
// Unary functor that extracts the imaginary component of its argument with
// std::imag. The component is converted to T on return, so the caller gets a
// value of the input type.
struct Imag {
  template <typename T>
  T operator()(T value) {
    const auto imaginary_part = std::imag(value);
    return imaginary_part;
  }
};
struct Log {
template <typename T>
T operator()(T x) {
@ -337,6 +344,13 @@ struct Negative {
}
};
// Unary functor that extracts the real component of its argument with
// std::real. The component is converted to T on return, so the caller gets a
// value of the input type.
struct Real {
  template <typename T>
  T operator()(T value) {
    const auto real_part = std::real(value);
    return real_part;
  }
};
struct Round {
template <typename T>
T operator()(T x) {

View File

@ -273,6 +273,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype);
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@ -398,6 +402,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
}
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@ -24,26 +24,26 @@ void set_unary_output_data(const array& in, array& out) {
}
}
template <typename T, typename Op>
void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) {
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
a += stride;
}
}
template <typename T, typename Op>
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
T* dst = out.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
T* dst = out.data<T>();
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {

View File

@ -40,18 +40,21 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
"v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);

View File

@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
// Returns the compute pipeline state for the unary kernel named
// `kernel_name` on device `d`.
// `in_type` / `out_type` are the dtypes of the kernel's input and output
// elements and `op` is the name of the unary operation. One definition in
// this change uses all three to generate the kernel source; another ignores
// them and looks the kernel up by name only.
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);

View File

@ -1,27 +1,27 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename Op>
template <typename T, typename U, typename Op>
[[kernel]] void unary_v(
device const T* in,
device T* out,
device U* out,
uint index [[thread_position_in_grid]]) {
out[index] = Op()(in[index]);
}
template <typename T, typename Op>
template <typename T, typename U, typename Op>
[[kernel]] void unary_v2(
device const T* in,
device T* out,
device U* out,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
out[offset] = Op()(in[offset]);
}
template <typename T, typename Op, int N = 1>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void unary_g(
device const T* in,
device T* out,
device U* out,
constant const int* in_shape,
constant const size_t* in_strides,
device const int& ndim,

View File

@ -5,26 +5,30 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4)
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)
#define instantiate_unary_float(op) \
instantiate_unary_all(op, float16, half) \
instantiate_unary_all(op, float32, float) \
instantiate_unary_all(op, bfloat16, bfloat16_t)
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_types(op) \
instantiate_unary_all(op, bool_, bool) \
instantiate_unary_all(op, uint8, uint8_t) \
instantiate_unary_all(op, uint16, uint16_t) \
instantiate_unary_all(op, uint32, uint32_t) \
instantiate_unary_all(op, uint64, uint64_t) \
instantiate_unary_all(op, int8, int8_t) \
instantiate_unary_all(op, int16, int16_t) \
instantiate_unary_all(op, int32, int32_t) \
instantiate_unary_all(op, int64, int64_t) \
instantiate_unary_all_same(op, bool_, bool) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_all_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t) \
instantiate_unary_float(op)
instantiate_unary_types(Abs)
@ -58,17 +62,19 @@ instantiate_unary_float(Tan)
instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
instantiate_unary_all(Abs, complex64, complex64_t)
instantiate_unary_all(Conjugate, complex64, complex64_t)
instantiate_unary_all(Cos, complex64, complex64_t)
instantiate_unary_all(Cosh, complex64, complex64_t)
instantiate_unary_all(Exp, complex64, complex64_t)
instantiate_unary_all(Negative, complex64, complex64_t)
instantiate_unary_all(Sign, complex64, complex64_t)
instantiate_unary_all(Sin, complex64, complex64_t)
instantiate_unary_all(Sinh, complex64, complex64_t)
instantiate_unary_all(Tan, complex64, complex64_t)
instantiate_unary_all(Tanh, complex64, complex64_t)
instantiate_unary_all(Round, complex64, complex64_t)
instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t)
instantiate_unary_all(Real, complex64, float32, complex64_t, float)
instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on

View File

@ -238,6 +238,13 @@ struct Floor {
};
};
// Unary functor that reads the `imag` member of its argument and returns it
// converted to the argument's own type T.
struct Imag {
  template <typename T>
  T operator()(T value) {
    const auto imaginary_part = value.imag;
    return imaginary_part;
  };
};
struct Log {
template <typename T>
T operator()(T x) {
@ -280,6 +287,13 @@ struct Negative {
};
};
// Unary functor that reads the `real` member of its argument and returns it
// converted to the argument's own type T.
struct Real {
  template <typename T>
  T operator()(T value) {
    const auto real_part = value.real;
    return real_part;
  };
};
struct Round {
template <typename T>
T operator()(T x) {

View File

@ -16,6 +16,7 @@ MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype,
Dtype,
const std::string) {
return d.get_kernel(kernel_name);
}

View File

@ -44,8 +44,8 @@ void unary_op_gpu_inplace(
} else {
kernel_name = (work_per_thread == 4 ? "gn4" : "g");
}
kernel_name += "_" + op + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
: MTL::Size(nthreads, 1, 1);
@ -124,11 +124,13 @@ UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Floor)
UNARY_GPU(Ceil)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)

View File

@ -62,6 +62,7 @@ NO_CPU(GatherQMM)
NO_CPU(Greater)
NO_CPU(GreaterEqual)
NO_CPU(Hadamard)
NO_CPU(Imag)
NO_CPU(Less)
NO_CPU(LessEqual)
NO_CPU(Load)
@ -83,6 +84,7 @@ NO_CPU(Power)
NO_CPU_MULTI(QRF)
NO_CPU(QuantizedMatmul)
NO_CPU(RandomBits)
NO_CPU(Real)
NO_CPU(Reduce)
NO_CPU(Reshape)
NO_CPU(Round)

View File

@ -64,6 +64,7 @@ NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Imag)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
@ -85,6 +86,7 @@ NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Real)
NO_GPU(Reduce)
NO_GPU(Reshape)
NO_GPU(Round)

View File

@ -33,7 +33,8 @@ bool is_unary(const Primitive& p) {
typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||
typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
typeid(p) == typeid(Expm1));
typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) ||
typeid(p) == typeid(Imag));
}
bool is_binary(const Primitive& p) {

View File

@ -4657,4 +4657,18 @@ array roll(
return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s);
}
// Real part of `a`.
// For a complex dtype the result is a float32 array with the shape of `a`,
// computed by the Real primitive on the stream resolved from `s`. For any
// other dtype `a` itself is returned.
array real(const array& a, StreamOrDevice s /* = {} */) {
  if (issubdtype(a.dtype(), complexfloating)) {
    return array(a.shape(), float32, std::make_shared<Real>(to_stream(s)), {a});
  }
  return a;
}
// Imaginary part of `a`.
// For a complex dtype the result is a float32 array with the shape of `a`,
// computed by the Imag primitive on the stream resolved from `s`. For any
// other dtype the result is the all-zero array produced by zeros_like.
array imag(const array& a, StreamOrDevice s /* = {} */) {
  if (!issubdtype(a.dtype(), complexfloating)) {
    // Forward `s` so the zeros are created on the caller's stream/device,
    // the same as the complex path below, rather than on the default one.
    return zeros_like(a, s);
  }
  return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
}
} // namespace mlx::core

View File

@ -1480,6 +1480,12 @@ array roll(
const std::vector<int>& axes,
StreamOrDevice s = {});
/**
 * The real part of a complex array. The result for a complex input is a
 * float32 array of the same shape; a non-complex input is returned unchanged.
 */
array real(const array& a, StreamOrDevice s = {});
/**
 * The imaginary part of a complex array. The result for a complex input is a
 * float32 array of the same shape; a non-complex input yields zeros_like(a).
 */
array imag(const array& a, StreamOrDevice s = {});
/** @} */
} // namespace mlx::core

View File

@ -1797,6 +1797,36 @@ std::vector<array> GreaterEqual::jvp(
return {zeros(shape, bool_, stream())};
}
std::vector<array> Imag::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {multiply(
array(complex64_t{0.0f, -1.0f}, primals[0].dtype()),
cotangents[0],
stream())};
}
std::vector<array> Imag::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {imag(tangents[0], stream())};
}
// Vectorized Imag: applies imag to the batched input and returns the input
// axes unchanged as the output axes.
std::pair<std::vector<array>, std::vector<int>> Imag::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  assert(inputs.size() == 1);
  assert(axes.size() == 1);
  std::vector<array> outputs{imag(inputs[0], stream())};
  return {outputs, axes};
}
std::pair<std::vector<array>, std::vector<int>> Less::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@ -2633,6 +2663,33 @@ bool RandomBits::is_equivalent(const Primitive& other) const {
return shape_ == r_other.shape_;
}
std::vector<array> Real::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {astype(cotangents[0], primals[0].dtype(), stream())};
}
std::vector<array> Real::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {real(tangents[0], stream())};
}
// Vectorized Real: applies real to the batched input and returns the input
// axes unchanged as the output axes.
std::pair<std::vector<array>, std::vector<int>> Real::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  assert(inputs.size() == 1);
  assert(axes.size() == 1);
  std::vector<array> outputs{real(inputs[0], stream())};
  return {outputs, axes};
}
std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {

View File

@ -1106,6 +1106,20 @@ class Hadamard : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
// Unary primitive for the imaginary part of an array. The `imag` op builds
// it with a float32 output for complex inputs.
class Imag : public UnaryPrimitive {
public:
// Binds the primitive to the stream it will be evaluated on.
explicit Imag(Stream stream) : UnaryPrimitive(stream) {}
// Backend-specific evaluation of the single input into `out`.
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
// Presumably declares vmap; Imag::vmap is defined in this change.
DEFINE_VMAP()
// Presumably declares jvp and vjp; both are defined in this change.
DEFINE_GRADS()
// NOTE(review): the three macros below are expanded elsewhere; by name they
// supply printing, a default equivalence check and shape inference.
// Confirm against the macro definitions.
DEFINE_PRINT(Imag)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
};
class Less : public UnaryPrimitive {
public:
explicit Less(Stream stream) : UnaryPrimitive(stream) {}
@ -1561,6 +1575,20 @@ class RandomBits : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
// Unary primitive for the real part of an array. The `real` op builds it
// with a float32 output for complex inputs.
class Real : public UnaryPrimitive {
public:
// Binds the primitive to the stream it will be evaluated on.
explicit Real(Stream stream) : UnaryPrimitive(stream) {}
// Backend-specific evaluation of the single input into `out`.
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
// Presumably declares vmap; Real::vmap is defined in this change.
DEFINE_VMAP()
// Presumably declares jvp and vjp; both are defined in this change.
DEFINE_GRADS()
// NOTE(review): the three macros below are expanded elsewhere; by name they
// supply printing, a default equivalence check and shape inference.
// Confirm against the macro definitions.
DEFINE_PRINT(Real)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
};
class Reshape : public UnaryPrimitive {
public:
explicit Reshape(Stream stream, const std::vector<int>& shape)

View File

@ -4842,4 +4842,42 @@ void init_ops(nb::module_& m) {
axis (int or tuple(int), optional): The axis or axes along which to
roll the elements.
)pbdoc");
m.def(
"real",
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::real(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def real(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Returns the real part of a complex array.
Args:
a (array): Input array.
Returns:
array: The real part of ``a``.
)pbdoc");
m.def(
"imag",
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::imag(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def imag(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Returns the imaginary part of a complex array.
Args:
a (array): Input array.
Returns:
array: The imaginary part of ``a``.
)pbdoc");
}

View File

@ -590,6 +590,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out1[0], out2[0]))
self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))
# Checks mx.grad through mx.real and mx.imag for a complex input vector.
def test_complex_vjps(self):
# Case 1: sum of twice the real part.
def fun(x):
return (2.0 * mx.real(x)).sum()
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
dfdx = mx.grad(fun)(x)
# The gradient is expected to be 2 for every element.
self.assertTrue(mx.allclose(dfdx, 2 * mx.ones_like(x)))
# Case 2: sum of twice the imaginary part.
def fun(x):
return (2.0 * mx.imag(x)).sum()
x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j])
dfdx = mx.grad(fun)(x)
# The gradient is expected to be -2j for every element.
self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
if __name__ == "__main__":
unittest.main()

View File

@ -2680,6 +2680,21 @@ class TestOps(mlx_tests.MLXTestCase):
y2 = mx.roll(x, s, a)
self.assertTrue(mx.array_equal(y1, y2).item())
# Checks mx.real / mx.imag on a non-complex array and on a complex array.
def test_real_imag(self):
# Input from mx.random.uniform: real returns the input, imag returns zeros.
x = mx.random.uniform(shape=(4, 4))
out = mx.real(x)
self.assertTrue(mx.array_equal(x, out))
out = mx.imag(x)
self.assertTrue(mx.array_equal(mx.zeros_like(x), out))
# Complex input built as x + 1j * y: both parts come back as float32 and
# match the arrays they were built from.
y = mx.random.uniform(shape=(4, 4))
z = x + 1j * y
self.assertEqual(mx.real(z).dtype, mx.float32)
self.assertTrue(mx.array_equal(mx.real(z), x))
self.assertEqual(mx.imag(z).dtype, mx.float32)
self.assertTrue(mx.array_equal(mx.imag(z), y))
if __name__ == "__main__":
unittest.main()