diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp
index e8bc3a83d..69cba09e9 100644
--- a/benchmarks/cpp/single_ops.cpp
+++ b/benchmarks/cpp/single_ops.cpp
@@ -233,6 +233,20 @@ void time_gather_scatter() {
   TIME(single_element_add);
 }
 
+void time_divmod() {
+  auto a = random::normal({1000});
+  auto b = random::normal({1000});
+  eval({a, b});
+
+  auto divmod_fused = [&a, &b]() { return divmod(a, b); };
+  TIME(divmod_fused);
+
+  auto divmod_separate = [&a, &b]() {
+    return std::vector<array>{floor_divide(a, b), remainder(a, b)};
+  };
+  TIME(divmod_separate);
+}
+
 int main() {
   std::cout << "Benchmarks for " << default_device() << std::endl;
   time_creation_ops();
@@ -246,4 +260,5 @@ int main() {
   time_matmul();
   time_reductions();
   time_gather_scatter();
+  time_divmod();
 }
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 0eff8b7c6..952f12c1e 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -36,6 +36,7 @@ Operations
    cosh
    dequantize
    divide
+   divmod
    equal
    erf
    erfinv
diff --git a/mlx/array.cpp b/mlx/array.cpp
index 2ec9214ab..cc85a497b 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -39,7 +39,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
 array::array(
     const std::vector<int>& shape,
     Dtype dtype,
-    std::unique_ptr<Primitive> primitive,
+    std::shared_ptr<Primitive> primitive,
     const std::vector<array>& inputs)
     : array_desc_(std::make_shared<ArrayDesc>(
           shape,
@@ -47,6 +47,23 @@ array::array(
           std::move(primitive),
           inputs)) {}
 
+std::vector<array> array::make_arrays(
+    const std::vector<std::vector<int>>& shapes,
+    const std::vector<Dtype>& dtypes,
+    std::shared_ptr<Primitive> primitive,
+    const std::vector<array>& inputs) {
+  std::vector<array> outputs;
+  for (int i = 0; i < shapes.size(); ++i) {
+    outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
+  }
+  for (int i = 0; i < outputs.size(); ++i) {
+    auto siblings = outputs;
+    siblings.erase(siblings.begin() + i);
+    outputs[i].set_siblings(std::move(siblings), i);
+  }
+  return outputs;
+}
+
 array::array(std::initializer_list<float> data)
     : array_desc_(std::make_shared<ArrayDesc>(
           std::vector<int>{static_cast<int>(data.size())},
@@ -66,6 +83,8 @@ array::array(
 
 void array::detach() {
   array_desc_->inputs.clear();
+  array_desc_->siblings.clear();
+  array_desc_->position = 0;
   array_desc_->primitive = nullptr;
 }
@@ -127,7 +146,7 @@ array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
 array::ArrayDesc::ArrayDesc(
     const std::vector<int>& shape,
     Dtype dtype,
-    std::unique_ptr<Primitive> primitive,
+    std::shared_ptr<Primitive> primitive,
     const std::vector<array>& inputs)
     : shape(shape),
       dtype(dtype),
@@ -139,10 +158,6 @@ array::ArrayDesc::ArrayDesc(
   }
 }
 
-// Needed because the Primitive type used in array.h is incomplete and the
-// compiler needs to see the call to the destructor after the type is complete.
-array::ArrayDesc::~ArrayDesc() = default;
-
 array::ArrayIterator::reference array::ArrayIterator::operator*() const {
   auto start = std::vector<int>(arr.ndim(), 0);
   auto end = arr.shape();
diff --git a/mlx/array.h b/mlx/array.h
index 8ea971347..fedfb8570 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -1,5 +1,4 @@
 // Copyright © 2023 Apple Inc.
-
 #pragma once
 #include <algorithm>
 #include <cstdint>
@@ -174,7 +173,13 @@ class array {
   array(
       const std::vector<int>& shape,
       Dtype dtype,
-      std::unique_ptr<Primitive> primitive,
+      std::shared_ptr<Primitive> primitive,
       const std::vector<array>& inputs);
 
+  static std::vector<array> make_arrays(
+      const std::vector<std::vector<int>>& shapes,
+      const std::vector<Dtype>& dtypes,
+      std::shared_ptr<Primitive> primitive,
+      const std::vector<array>& inputs);
+
   /** A unique identifier for an array. */
@@ -182,6 +187,11 @@ class array {
     return reinterpret_cast<std::uintptr_t>(array_desc_.get());
   }
 
+  /** A unique identifier for the array's primitive. */
+  std::uintptr_t primitive_id() const {
+    return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
+  }
+
   struct Data {
     allocator::Buffer buffer;
     deleter_t d;
@@ -219,12 +229,32 @@ class array {
     return array_desc_->inputs;
   };
 
-  /** A non-const reference to the array's inputs so that they can be used to
-   * edit the graph. */
-  std::vector<array>& editable_inputs() {
+  std::vector<array>& inputs() {
     return array_desc_->inputs;
   }
 
+  /** The array's siblings. */
+  const std::vector<array>& siblings() const {
+    return array_desc_->siblings;
+  };
+
+  void set_siblings(std::vector<array> siblings, uint16_t position) {
+    array_desc_->siblings = std::move(siblings);
+    array_desc_->position = position;
+  }
+
+  /** The outputs of the array's primitive (i.e. this array and
+   *  its siblings) in the order the primitive expects. */
+  std::vector<array> outputs() const {
+    auto idx = array_desc_->position;
+    std::vector<array> outputs;
+    outputs.reserve(siblings().size() + 1);
+    outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
+    outputs.push_back(*this);
+    outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
+    return outputs;
+  };
+
   /** Detach the array from the graph. */
   void detach();
@@ -299,7 +329,7 @@ class array {
     std::vector<size_t> strides;
     size_t size;
     Dtype dtype;
-    std::unique_ptr<Primitive> primitive{nullptr};
+    std::shared_ptr<Primitive> primitive{nullptr};
 
     // Indicates an array is being used in a graph transform
     // and should not be detached from the graph
@@ -321,16 +351,19 @@ class array {
     Flags flags;
 
     std::vector<array> inputs;
+    // An array to keep track of the siblings from a multi-output
+    // primitive.
+    std::vector<array> siblings;
+    // The array's position in the output list
+    uint32_t position{0};
 
     explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
     explicit ArrayDesc(
         const std::vector<int>& shape,
         Dtype dtype,
-        std::unique_ptr<Primitive> primitive,
+        std::shared_ptr<Primitive> primitive,
         const std::vector<array>& inputs);
-
-    ~ArrayDesc();
   };
 
   // The ArrayDesc contains the details of the materialized array including the
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 335f26990..e52a47964 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -17,6 +17,12 @@
       primitive::eval(inputs, out);                                     \
   }
 
+#define DEFAULT_MULTI(primitive)                                        \
+  void primitive::eval_cpu(                                             \
+      const std::vector<array>& inputs, std::vector<array>& outputs) {  \
+    primitive::eval(inputs, outputs);                                   \
+  }
+
 namespace mlx::core {
 
 // Use the default implementation for the following primitives
@@ -57,6 +63,7 @@ DEFAULT(Slice)
 DEFAULT(Sort)
 DEFAULT(StopGradient)
 DEFAULT(Transpose)
+DEFAULT_MULTI(DivMod)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp
index 4199e9181..8e3de02d3 100644
--- a/mlx/backend/common/binary.cpp
+++ b/mlx/backend/common/binary.cpp
@@ -6,6 +6,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/binary_two.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
@@ -75,6 +76,61 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
   binary(a, b, out, [](auto x, auto y) { return x + y; });
 }
 
+void DivMod::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto integral_op = [](auto x, auto y) {
+    return std::make_pair(x / y, x % y);
+  };
+  auto float_op = [](auto x, auto y) {
+    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
+  };
+  switch (outputs[0].dtype()) {
+    case bool_:
+      binary_op<bool>(a, b, outputs, integral_op);
+      break;
+    case uint8:
+      binary_op<uint8_t>(a, b, outputs, integral_op);
+      break;
+    case uint16:
+      binary_op<uint16_t>(a, b, outputs, integral_op);
+      break;
+    case uint32:
+      binary_op<uint32_t>(a, b, outputs, integral_op);
+      break;
+    case uint64:
+      binary_op<uint64_t>(a, b, outputs, integral_op);
+      break;
+    case int8:
+      binary_op<int8_t>(a, b, outputs, integral_op);
+      break;
+    case int16:
+      binary_op<int16_t>(a, b, outputs, integral_op);
+      break;
+    case int32:
+      binary_op<int32_t>(a, b, outputs, integral_op);
+      break;
+    case int64:
+      binary_op<int64_t>(a, b, outputs, integral_op);
+      break;
+    case float16:
+      binary_op<float16_t>(a, b, outputs, float_op);
+      break;
+    case float32:
+      binary_op<float>(a, b, outputs, float_op);
+      break;
+    case bfloat16:
+      binary_op<bfloat16_t>(a, b, outputs, float_op);
+      break;
+    case complex64:
+      // Should never get here
+      throw std::runtime_error("[DivMod] Complex type not supported");
+      break;
+  }
+}
+
 void Divide::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h
index 4b0631746..c37488559 100644
--- a/mlx/backend/common/binary.h
+++ b/mlx/backend/common/binary.h
@@ -73,6 +73,12 @@ struct UseDefaultBinaryOp {
     // Should we throw? This should normally never be called.
     assert(false);
   }
+
+  template <typename T, typename U>
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    // Should we throw? This should normally never be called.
+    assert(false);
+  }
 };
 
 template <typename T, typename U, typename Op>
@@ -89,6 +95,18 @@ struct DefaultVectorScalar {
       a++;
     }
   }
+
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    T scalar = *b;
+    while (size-- > 0) {
+      auto dst = op(*a, scalar);
+      *dst_a = dst.first;
+      *dst_b = dst.second;
+      dst_a++;
+      dst_b++;
+      a++;
+    }
+  }
 };
 
 template <typename T, typename U, typename Op>
@@ -105,6 +123,18 @@ struct DefaultScalarVector {
       b++;
     }
   }
+
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    T scalar = *a;
+    while (size-- > 0) {
+      auto dst = op(scalar, *b);
+      *dst_a = dst.first;
+      *dst_b = dst.second;
+      dst_a++;
+      dst_b++;
+      b++;
+    }
+  }
 };
 
 template <typename T, typename U, typename Op>
@@ -121,6 +151,18 @@ struct DefaultVectorVector {
       b++;
     }
   }
+
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    while (size-- > 0) {
+      auto dst = op(*a, *b);
+      *dst_a = dst.first;
+      *dst_b = dst.second;
+      dst_a++;
+      dst_b++;
+      a++;
+      b++;
+    }
+  }
 };
 
 template <typename T, typename U, typename Op>
diff --git a/mlx/backend/common/binary_two.h b/mlx/backend/common/binary_two.h
new file mode 100644
index 000000000..3468cb61e
--- /dev/null
+++ b/mlx/backend/common/binary_two.h
@@ -0,0 +1,536 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+namespace {
+
+template <typename T, typename U, typename Op>
+void binary_op_dims1(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op) {
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* dst_a = out_a.data<U>();
+  U* dst_b = out_b.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  for (size_t i = 0; i < out_a.size(); ++i) {
+    auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+    dst_a[i] = dst.first;
+    dst_b[i] = dst.second;
+    a_idx += a.strides()[0];
+    b_idx += b.strides()[0];
+  }
+}
+
+template <typename T, typename U, typename Op>
+void binary_op_dims1(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op,
+    int stride) {
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* dst_a = out_a.data<U>();
+  U* dst_b = out_b.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; i++) {
+    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+    a_idx += a.strides()[0];
+    b_idx += b.strides()[0];
+    dst_a += stride;
+    dst_b += stride;
+  }
+}
+
+template <typename T, typename U, typename Op>
+void binary_op_dims2(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op) {
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* dst_a = out_a.data<U>();
+  U* dst_b = out_b.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+      dst_a[out_idx] = dst.first;
+      dst_b[out_idx++] = dst.second;
+      a_idx += a.strides()[1];
+      b_idx += b.strides()[1];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+  }
+}
+
+template <typename T, typename U, typename Op>
+void binary_op_dims2(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op,
+    int stride) {
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* dst_a = out_a.data<U>();
+  U* dst_b = out_b.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+      a_idx += a.strides()[1];
+      b_idx += b.strides()[1];
+      dst_a += stride;
+      dst_b += stride;
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+  }
+}
+
+template <typename T, typename U, typename Op>
+void binary_op_dims3(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op) {
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* dst_a = out_a.data<U>();
+  U* dst_b = out_b.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      for (size_t k = 0; k < a.shape()[2]; ++k) {
+        auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+        dst_a[out_idx] = dst.first;
+        dst_b[out_idx++] = dst.second;
+        a_idx += a.strides()[2];
+        b_idx += b.strides()[2];
+      }
+      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+  }
+}
+
+template <typename T, typename U, typename Op>
+void binary_op_dims4(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op) {
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* dst_a = out_a.data<U>();
+  U* dst_b = out_b.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      for (size_t k = 0; k < a.shape()[2]; ++k) {
+        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+          auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
+          dst_a[out_idx] = dst.first;
+          dst_b[out_idx++] = dst.second;
+          a_idx += a.strides()[3];
+          b_idx += b.strides()[3];
+        }
+        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+      }
+      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+  }
+}
+
+template <typename T, typename U, typename Op>
+void binary_op_dispatch_dims(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op) {
+  switch (out_a.ndim()) {
+    case 1:
+      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
+      return;
+    case 2:
+      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
+      return;
+    case 3:
+      binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
+      return;
+    case 4:
+      binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
+      return;
+  }
+
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* dst_a = out_a.data<U>();
+  U* dst_b = out_b.data<U>();
+  for (size_t i = 0; i < out_a.size(); i++) {
+    int a_idx = elem_to_loc(i, a.shape(), a.strides());
+    int b_idx = elem_to_loc(i, b.shape(), b.strides());
+    std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
+  }
+}
+
+template <typename T, typename U, typename Op>
+void binary_op_dispatch_dims(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op,
+    int dim,
+    int stride) {
+  // Number of dimensions to loop over for vectorized ops
+  switch (dim) {
+    case 1:
+      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
+      return;
+    case 2:
+      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
+      return;
+  }
+
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* dst_a = out_a.data<U>();
+  U* dst_b = out_b.data<U>();
+  for (size_t i = 0; i < out_a.size(); i += stride) {
+    int a_idx = elem_to_loc(i, a.shape(), a.strides());
+    int b_idx = elem_to_loc(i, b.shape(), b.strides());
+    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
+    dst_a += stride;
+    dst_b += stride;
+  }
+}
+
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename OpSV,
+    typename OpVS,
+    typename OpVV>
+void binary_op(
+    const array& a,
+    const array& b,
+    array& out_a,
+    array& out_b,
+    Op op,
+    OpSV opsv,
+    OpVS opvs,
+    OpVV opvv) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out_a, bopt);
+  set_binary_op_output_data(a, b, out_b, bopt);
+
+  // The full computation is scalar scalar so call the base op once
+  if (bopt == ScalarScalar) {
+    std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
+        op(*a.data<T>(), *b.data<T>());
+    return;
+  }
+
+  // The full computation is scalar vector so delegate to the op
+  if (bopt == ScalarVector) {
+    opsv(
+        a.data<T>(),
+        b.data<T>(),
+        out_a.data<U>(),
+        out_b.data<U>(),
+        b.data_size());
+    return;
+  }
+
+  // The full computation is vector scalar so delegate to the op
+  if (bopt == VectorScalar) {
+    opvs(
+        a.data<T>(),
+        b.data<T>(),
+        out_a.data<U>(),
+        out_b.data<U>(),
+        a.data_size());
+    return;
+  }
+
+  // The full computation is vector vector so delegate to the op
+  if (bopt == VectorVector) {
+    opvv(
+        a.data<T>(),
+        b.data<T>(),
+        out_a.data<U>(),
+        out_b.data<U>(),
+        out_a.size());
+    return;
+  }
+
+  // General computation so let's try to optimize
+
+  // Get the left-most dim such that the array is row contiguous after
+  auto& strides = out_a.strides();
+  auto leftmost_rc_dim = [&strides](const array& arr) {
+    int d = arr.ndim() - 1;
+    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
+    }
+    return d + 1;
+  };
+  auto a_rc_dim = leftmost_rc_dim(a);
+  auto b_rc_dim = leftmost_rc_dim(b);
+
+  // Get the left-most dim such that the array is a broadcasted "scalar" after
+  auto leftmost_s_dim = [](const array& arr) {
+    int d = arr.ndim() - 1;
+    for (; d >= 0 && arr.strides()[d] == 0; d--) {
+    }
+    return d + 1;
+  };
+  auto a_s_dim = leftmost_s_dim(a);
+  auto b_s_dim = leftmost_s_dim(b);
+
+  auto ndim = out_a.ndim();
+
+  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
+  int dim = ndim;
+  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
+    bopt = VectorVector;
+    dim = d;
+    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
+    // contiguous
+  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
+    bopt = VectorScalar;
+    dim = d;
+    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
+    // contiguous
+  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
+    bopt = ScalarVector;
+    dim = d;
+  }
+
+  // Can be sure dim > 0 since otherwise we would have used one of the fully
+  // contiguous methods above. Except for the case that the flags do not
+  // correspond to the underlying contiguity.
+  size_t stride;
+  if (dim == 0 || strides[dim - 1] < 16) {
+    stride = 1;
+    bopt = General;
+    dim = ndim;
+  } else {
+    stride = strides[dim - 1];
+  }
+
+  switch (bopt) {
+    case VectorVector:
+      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
+      break;
+    case VectorScalar:
+      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
+      break;
+    case ScalarVector:
+      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
+      break;
+    default:
+      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
+      break;
+  }
+}
+
+template <
+    typename T,
+    typename U = T,
+    typename Op,
+    typename OpSV,
+    typename OpVS,
+    typename OpVV>
+void binary_op(
+    const array& a,
+    const array& b,
+    std::vector<array>& outputs,
+    Op op,
+    OpSV opsv,
+    OpVS opvs,
+    OpVV opvv) {
+  // TODO: The following mess of constexpr evaluations can probably be achieved
+  //       with template specializations and overloading. Would it be simpler?
+
+  if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
+    if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+      if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
+        binary_op<T, U>(
+            a,
+            b,
+            outputs[0],
+            outputs[1],
+            op,
+            DefaultScalarVector<T, U, Op>(op),
+            DefaultVectorScalar<T, U, Op>(op),
+            DefaultVectorVector<T, U, Op>(op));
+      } else {
+        // opsv and opvs were UseDefaultBinaryOp
+        binary_op<T, U>(
+            a,
+            b,
+            outputs[0],
+            outputs[1],
+            op,
+            DefaultScalarVector<T, U, Op>(op),
+            DefaultVectorScalar<T, U, Op>(op),
+            opvv);
+      }
+    } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+      // opsv and opvv were UseDefaultBinaryOp
+      binary_op<T, U>(
+          a,
+          b,
+          outputs[0],
+          outputs[1],
+          op,
+          DefaultScalarVector<T, U, Op>(op),
+          opvs,
+          DefaultVectorVector<T, U, Op>(op));
+    } else {
+      // opsv was UseDefaultBinaryOp
+      binary_op<T, U>(
+          a,
+          b,
+          outputs[0],
+          outputs[1],
+          op,
+          DefaultScalarVector<T, U, Op>(op),
+          opvs,
+          opvv);
+    }
+  } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+    if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+      // opvs and opvv were UseDefaultBinaryOp
+      binary_op<T, U>(
+          a,
+          b,
+          outputs[0],
+          outputs[1],
+          op,
+          opsv,
+          DefaultVectorScalar<T, U, Op>(op),
+          DefaultVectorVector<T, U, Op>(op));
+    } else {
+      // opvs was UseDefaultBinaryOp
+      binary_op<T, U>(
+          a,
+          b,
+          outputs[0],
+          outputs[1],
+          op,
+          opsv,
+          DefaultVectorScalar<T, U, Op>(op),
+          opvv);
+    }
+  } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+    // opvv was UseDefaultBinaryOp
+    binary_op<T, U>(
+        a,
+        b,
+        outputs[0],
+        outputs[1],
+        op,
+        opsv,
+        opvs,
+        DefaultVectorVector<T, U, Op>(op));
+  } else {
+    // All ops provided
+    binary_op<T, U>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
+  }
+}
+
+template <typename T, typename U = T, typename Op>
+void binary_op(
+    const array& a,
+    const array& b,
+    std::vector<array>& outputs,
+    Op op) {
+  DefaultScalarVector<T, U, Op> opsv(op);
+  DefaultVectorScalar<T, U, Op> opvs(op);
+  DefaultVectorVector<T, U, Op> opvv(op);
+  binary_op<T, U>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
+}
+
+template <typename... Ops>
+void binary(
+    const array& a,
+    const array& b,
+    std::vector<array>& outputs,
+    Ops... ops) {
+  switch (outputs[0].dtype()) {
+    case bool_:
+      binary_op<bool>(a, b, outputs, ops...);
+      break;
+    case uint8:
+      binary_op<uint8_t>(a, b, outputs, ops...);
+      break;
+    case uint16:
+      binary_op<uint16_t>(a, b, outputs, ops...);
+      break;
+    case uint32:
+      binary_op<uint32_t>(a, b, outputs, ops...);
+      break;
+    case uint64:
+      binary_op<uint64_t>(a, b, outputs, ops...);
+      break;
+    case int8:
+      binary_op<int8_t>(a, b, outputs, ops...);
+      break;
+    case int16:
+      binary_op<int16_t>(a, b, outputs, ops...);
+      break;
+    case int32:
+      binary_op<int32_t>(a, b, outputs, ops...);
+      break;
+    case int64:
+      binary_op<int64_t>(a, b, outputs, ops...);
+      break;
+    case float16:
+      binary_op<float16_t>(a, b, outputs, ops...);
+      break;
+    case float32:
+      binary_op<float>(a, b, outputs, ops...);
+      break;
+    case bfloat16:
+      binary_op<bfloat16_t>(a, b, outputs, ops...);
+      break;
+    case complex64:
+      binary_op<complex64_t>(a, b, outputs, ops...);
+      break;
+  }
+}
+
+} // namespace
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 917286ea7..66f224624 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -16,6 +16,12 @@
       primitive::eval(inputs, out);                                     \
   }
 
+#define DEFAULT_MULTI(primitive)                                        \
+  void primitive::eval_cpu(                                             \
+      const std::vector<array>& inputs, std::vector<array>& outputs) {  \
+    primitive::eval(inputs, outputs);                                   \
+  }
+
 namespace mlx::core {
 
 DEFAULT(Abs)
@@ -89,6 +95,7 @@ DEFAULT(Subtract)
 DEFAULT(Tan)
 DEFAULT(Tanh)
 DEFAULT(Transpose)
+DEFAULT_MULTI(DivMod)
 
 void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (out.dtype() != float32) {
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index e65430c25..3ec50fb67 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -14,6 +14,7 @@ set(
   "arange"
   "arg_reduce"
   "binary"
+  "binary_two"
   "conv"
   "copy"
   "gemm"
diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal
new file mode 100644
index 000000000..3e4dbc8f2
--- /dev/null
+++ b/mlx/backend/metal/kernels/binary_two.metal
@@ -0,0 +1,259 @@
+// Copyright © 2023 Apple Inc.
+
+#include <metal_integer>
+#include <metal_math>
+
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/bf16.h"
+
+struct FloorDivide {
+  template <typename T> T operator()(T x, T y) { return x / y; }
+  template <> float operator()(float x, float y) { return trunc(x / y); }
+  template <> half operator()(half x, half y) { return trunc(x / y); }
+  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
+};
+
+struct Remainder {
+  template <typename T> T operator()(T x, T y) { return x % y; }
+  template <> float operator()(float x, float y) { return fmod(x, y); }
+  template <> half operator()(half x, half y) { return fmod(x, y); }
+  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
+};
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_s2s(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op1()(a[0], b[0]);
+  d[index] = Op2()(a[0], b[0]);
+}
+
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_ss(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op1()(a[0], b[0]);
+  d[index] = Op2()(a[0], b[0]);
+}
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_sv(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op1()(a[0], b[index]);
+  d[index] = Op2()(a[0], b[index]);
+}
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_vs(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op1()(a[index], b[0]);
+  d[index] = Op2()(a[index], b[0]);
+}
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_vv(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op1()(a[index], b[index]);
+  d[index] = Op2()(a[index], b[index]);
+}
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_g_nd1(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    constant const size_t& a_stride,
+    constant const size_t& b_stride,
+    uint index [[thread_position_in_grid]]) {
+  auto a_idx = elem_to_loc_1(index, a_stride);
+  auto b_idx = elem_to_loc_1(index, b_stride);
+  c[index] = Op1()(a[a_idx], b[b_idx]);
+  d[index] = Op2()(a[a_idx], b[b_idx]);
+}
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_g_nd2(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    constant const size_t a_strides[2],
+    constant const size_t b_strides[2],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_2(index, a_strides);
+  auto b_idx = elem_to_loc_2(index, b_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  c[out_idx] = Op1()(a[a_idx], b[b_idx]);
+  d[out_idx] = Op2()(a[a_idx], b[b_idx]);
+}
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_g_nd3(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    constant const size_t a_strides[3],
+    constant const size_t b_strides[3],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_3(index, a_strides);
+  auto b_idx = elem_to_loc_3(index, b_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  c[out_idx] = Op1()(a[a_idx], b[b_idx]);
+  d[out_idx] = Op2()(a[a_idx], b[b_idx]);
+}
+
+template <typename T, typename U, typename Op1, typename Op2, int DIM>
+[[kernel]] void binary_op_g_nd(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    constant const int shape[DIM],
+    constant const size_t a_strides[DIM],
+    constant const size_t b_strides[DIM],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
+  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
+}
+
+template <typename T, typename U, typename Op1, typename Op2>
+[[kernel]] void binary_op_g(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    device U* d,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const int& ndim,
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
+  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
+  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
+}
+
+#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
+  template [[host_name(name)]] \
+  [[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
+      device const itype* a, \
+      device const itype* b, \
+      device otype* c, \
+      device otype* d, \
+      uint index [[thread_position_in_grid]]);
+
+#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
+  template [[host_name(name "_" #dims)]] \
+  [[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
+      device const itype* a, \
+      device const itype* b, \
+      device otype* c, \
+      device otype* d, \
+      constant const int shape[dims], \
+      constant const size_t a_strides[dims], \
+      constant const size_t b_strides[dims], \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]);
+
+#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
+  template [[host_name(name "_1")]] \
+  [[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
+      device const itype* a, \
+      device const itype* b, \
+      device otype* c, \
+      device otype* d, \
+      constant const size_t& a_stride, \
+      constant const size_t& b_stride, \
+      uint index [[thread_position_in_grid]]); \
+  template [[host_name(name "_2")]] \
+  [[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
+      device const itype* a, \
+      device const itype* b, \
+      device otype* c, \
+      device otype* d, \
+      constant const size_t a_strides[2], \
+      constant const size_t b_strides[2], \
+      uint2 index [[thread_position_in_grid]], \
+      uint2 grid_dim [[threads_per_grid]]); \
+  template [[host_name(name "_3")]] \
+  [[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
+      device const itype* a, \
+      device const itype* b, \
+      device otype* c, \
+      device otype* d, \
+      constant const size_t a_strides[3], \
+      constant const size_t b_strides[3], \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]); \
+  instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
+  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
+
+
+#define instantiate_binary_g(name, itype, otype, op1, op2) \
+  template [[host_name(name)]] \
+  [[kernel]] void binary_op_g<itype, otype, op1, op2>( \
+      device const itype* a, \
+      device const itype* b, \
+      device otype* c, \
+      device otype* d, \
+      constant const int* shape, \
+      constant const size_t* a_strides, \
+      constant const size_t* b_strides, \
+      constant const int& ndim, \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]);
+#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
+  instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
+  instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
+  instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
+  instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
+  instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
+  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
+
+#define instantiate_binary_float(name, op1, op2) \
+  instantiate_binary_all(name, float16, half, half, op1, op2) \
+  instantiate_binary_all(name, float32, float, float, op1, op2) \
+  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
+
+#define instantiate_binary_types(name, op1, op2) \
+  instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
+  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
+  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
+  instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
+  instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
+  instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
+  instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
+  instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
+  instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
+  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
+  instantiate_binary_float(name, op1, op2)
+
+instantiate_binary_types(divmod, FloorDivide, Remainder)
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index 5c0f2d90e..8436dd3d9 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -4,7 +4,6 @@
 #include <future>
 #include <memory>
 
-#include "mlx/array.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/primitives.h"
 #include "mlx/scheduler.h"
@@ -54,7 +53,8 @@ std::function<void()> make_task(
     }
     auto s = arr.primitive().stream();
     auto command_buffer = increment_command_buffer(s);
-    arr.primitive().eval_gpu(arr.inputs(), arr);
+    auto outputs = arr.outputs();
+    arr.primitive().eval_gpu(arr.inputs(), outputs);
     if (p) {
       metal::device(s.device).end_encoding(s.index);
       scheduler::notify_new_task(s);
@@ -62,6 +62,9 @@ std::function<void()> make_task(
           [s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
             if (!arr.is_tracer()) {
               arr.detach();
+              for (auto s : arr.siblings()) {
+                s.detach();
+              }
             }
             p->set_value();
             scheduler::notify_task_completion(s);
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 35f91ffcb..e8dac264b 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -19,6 +19,98 @@ namespace {
 
 static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
 
+void binary_op(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    const std::string op) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, outputs[0], bopt);
+  set_binary_op_output_data(a, b, outputs[1], bopt);
+
+  auto& out = outputs[0];
+
+  // Try to collapse contiguous dims
+  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
+  auto& strides_a = strides[0];
+  auto& strides_b = strides[1];
+  auto& strides_out = strides[2];
+
+  std::ostringstream kname;
+  switch (bopt) {
+    case ScalarScalar:
+      kname << "ss";
+      break;
+    case ScalarVector:
+      kname << "sv";
+      break;
+    case VectorScalar:
+      kname << "vs";
+      break;
+    case VectorVector:
"vv"; + break; + case General: + kname << "g"; + break; + } + kname << op << type_to_name(a); + if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) { + kname << "_" << shape.size(); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + auto kernel = d.get_kernel(kname.str()); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, outputs[0], 2); + set_array_buffer(compute_encoder, outputs[1], 3); + + if (bopt == General) { + auto ndim = shape.size(); + if (ndim > 3) { + compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6); + } else { + // The shape is implicit in the grid for <= 3D + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); + } + + if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { + compute_encoder->setBytes(&ndim, sizeof(int), 7); + } + + // Launch up to 3D grid of threads + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; + size_t rest = out.size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + // Launch a 1D grid of threads + size_t nthreads = out.data_size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } +} + void binary_op( const std::vector& inputs, array& out, @@ -364,6 +456,12 @@ void Divide::eval_gpu(const std::vector& inputs, array& out) { binary_op(inputs, out, "div"); } +void DivMod::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + binary_op(inputs, outputs, "divmod"); +} + void Remainder::eval_gpu(const std::vector& inputs, array& out) { binary_op(inputs, out, "rem"); } diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 4c4f1a76e..90019f65b 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -2,6 +2,12 @@ #include "mlx/primitives.h" +#define NO_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + throw std::runtime_error(#func " has no GPU implementation."); \ + } + #define NO_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ throw std::runtime_error(#func " has no GPU implementation."); \ @@ -81,5 +87,6 @@ NO_GPU(Subtract) NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Transpose) +NO_GPU_MULTI(DivMod) } // namespace mlx::core diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index 7c83fdeac..1cbc5a987 100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -12,13 +12,11 @@ namespace mlx::core { -using OptionalArrayRef = std::optional>; - -struct ArrayNames { +struct NodeNamer { std::unordered_map 
 
-  std::string get_name(const array& x) {
-    auto it = names.find(x.id());
+  std::string get_name(uintptr_t id) {
+    auto it = names.find(id);
     if (it == names.end()) {
       // Get the next name in the sequence
       // [A, B, ..., Z, AA, AB, ...]
@@ -29,45 +27,42 @@
        var_num = (var_num - 1) / 26;
      }
      std::string name(letters.rbegin(), letters.rend());
-      names.insert({x.id(), name});
+      names.insert({id, name});
      return name;
    }
    return it->second;
  }
+
+  std::string get_name(const array& x) {
+    return get_name(x.id());
+  }
 };
 
 void depth_first_traversal(
-    std::function<void(OptionalArrayRef, const array&, int)> callback,
+    std::function<void(const array&)> callback,
     const std::vector<array>& outputs) {
-  std::function<void(OptionalArrayRef, const array&, int)> recurse;
+  std::function<void(const array&)> recurse;
   std::unordered_set<std::uintptr_t> cache;
-  recurse = [&](OptionalArrayRef parent, const array& x, int input_index) {
+  recurse = [&](const array& x) {
     auto id = x.id();
     if (cache.find(id) != cache.end()) {
       return;
     }
     cache.insert(id);
-    for (int i = 0; i < x.inputs().size(); i++) {
-      recurse(x, x.inputs()[i], i);
+    for (auto& s : x.siblings()) {
+      cache.insert(s.id());
     }
-    callback(parent, x, input_index);
+    for (auto& in : x.inputs()) {
+      recurse(in);
+    }
+    callback(x);
   };
-  for (auto x : outputs) {
-    recurse(std::nullopt, x, 0);
+  for (auto& o : outputs) {
+    recurse(o);
   }
 }
 
-void depth_first_traversal(
-    std::function<void(const array&)> callback,
-    const std::vector<array>& outputs) {
-  depth_first_traversal(
-      [&callback](OptionalArrayRef p, const array& x, int input_index) {
-        callback(x);
-      },
-      outputs);
-}
-
 void print_graph(std::ostream& os, const std::vector<array>& outputs) {
   std::vector<array> tape;
   std::vector<array> inputs;
@@ -82,15 +77,11 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
       },
       outputs);
 
-  ArrayNames namer;
-  auto print_arr = [&namer, &os](const array& a) {
-    os << namer.get_name(a);
-    os << " [" << a.shape() << ", " << a.dtype() << "]";
-  };
-
-  auto print_arrs = [&](const std::vector<array>& arrs) {
+  NodeNamer namer;
+  auto print_arrs = [&namer, &os](std::vector<array> arrs) {
     for (auto& arr : arrs) {
-      print_arr(arr);
+      os << namer.get_name(arr);
+      os << " [" << arr.shape() << ", " << arr.dtype() << "]";
       if (&arr != &arrs.back()) {
        os << ", ";
      }
@@ -108,7 +99,7 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
     os << " ";
     print_arrs(arr.inputs());
     os << " -> ";
-    print_arr(arr);
+    print_arrs(arr.outputs());
     os << "\n";
   }
 }
@@ -116,26 +107,47 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
 void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
   os << "digraph {" << std::endl;
 
-  ArrayNames namer;
+  std::unordered_set<std::uintptr_t> output_set;
+  for (auto& o : outputs) {
+    output_set.insert(o.id());
+  }
+  std::unordered_set<std::uintptr_t> input_set;
+  NodeNamer namer;
   depth_first_traversal(
-      [&namer, &os](auto parent, const array& x, int input_index) {
-        os << "{ ";
+      [&](const array& x) {
         if (!x.has_primitive()) {
-          os << "rank=source; ";
+          input_set.insert(x.id());
+          os << "{ rank=source; " << namer.get_name(x) << "; }" << std::endl;
+          return;
         }
-        if (!parent) {
-          os << "rank=sink; ";
-        }
-        os << namer.get_name(x);
+
+        // Node for primitive
         if (x.has_primitive()) {
+          os << "{ ";
+          os << namer.get_name(x.primitive_id());
           os << " [label =\"";
           x.primitive().print(os);
-          os << "\"]";
+          os << "\", shape=rectangle]";
+          os << "; }" << std::endl;
+          // Arrows to primitive's inputs
+          for (auto& a : x.inputs()) {
+            os << namer.get_name(x.primitive_id()) << " -> "
+               << namer.get_name(a) << std::endl;
+          }
         }
-        os << "; }" << std::endl;
-        for (auto c : x.inputs()) {
-          os << namer.get_name(c) << " -> " << namer.get_name(x) << std::endl;
+
+        // Point outputs to their primitive
+        for (auto& a : x.outputs()) {
+          os << "{ ";
+          if (output_set.find(a.id()) != output_set.end()) {
+            os << "rank=sink; ";
+          }
+          os << namer.get_name(a);
+          os << "; }" << std::endl;
+          if (x.has_primitive()) {
+            os << namer.get_name(a) << " -> "
+               << namer.get_name(x.primitive_id()) << std::endl;
+          }
         }
       },
       outputs);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 4a2740d1f..98a2bd889 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1737,6 +1737,21 @@ array operator%(const array& a, const array& b) {
   return remainder(a, b);
 }
 
+std::vector<array>
+divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  auto dtype = promote_types(a.dtype(), b.dtype());
+  if (is_complex(dtype)) {
+    throw std::invalid_argument("[divmod] Complex type not supported.");
+  }
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
+  return array::make_arrays(
+      {inputs[0].shape(), inputs[0].shape()},
+      {inputs[0].dtype(), inputs[0].dtype()},
+      std::make_unique<DivMod>(to_stream(s)),
+      inputs);
+}
+
 array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
diff --git a/mlx/ops.h b/mlx/ops.h
index 2b8061f64..d02d39717 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -721,6 +721,10 @@ array operator/(const array& a, const array& b);
 array operator/(double a, const array& b);
 array operator/(const array& a, double b);
 
+/** Compute the element-wise quotient and remainder. */
+std::vector<array>
+divmod(const array& a, const array& b, StreamOrDevice s = {});
+
 /** Compute integer division. Equivalent to doing floor(a / x). */
 array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
 
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index d2b3c23e2..bc2a2a036 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -50,7 +50,7 @@ std::tuple<array, array, int> vmap_binary_op(
 
 } // namespace
 
-array Primitive::jvp(
+std::vector<array> Primitive::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -59,12 +59,12 @@
 
 std::vector<array> Primitive::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
   throw std::invalid_argument("Primitive's vjp not implemented.");
 };
 
-std::pair<array, int> Primitive::vmap(
+std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   throw std::invalid_argument("Primitive's vmap not implemented.");
 }
@@ -72,52 +72,53 @@
 
 std::vector<array> Abs::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  return {jvp(primals, {cotan}, argnums)};
+  return jvp(primals, cotangents, argnums);
 }
 
-array Abs::jvp(
+std::vector<array> Abs::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
-  return multiply(tangents[0], sign(primals[0], stream()), stream());
+  return {multiply(tangents[0], sign(primals[0], stream()), stream())};
 }
 
-std::pair<array, int> Abs::vmap(
+std::pair<std::vector<array>, std::vector<int>> Abs::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {abs(inputs[0], stream()), axes[0]};
+  return {{abs(inputs[0], stream())}, axes};
 }
 
-array Add::jvp(
+std::vector<array> Add::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  return tangents.size() > 1 ? add(tangents[0], tangents[1], stream())
-                             : tangents[0];
+  return {
+      tangents.size() > 1 ? add(tangents[0], tangents[1], stream())
+                          : tangents[0]};
 }
 
 std::vector<array> Add::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
   if (argnums.size() == 1) {
-    return {cotan};
+    return cotangents;
   } else {
-    return {cotan, cotan};
+    return {cotangents[0], cotangents[0]};
   }
 }
 
-std::pair<array, int> Add::vmap(
+std::pair<std::vector<array>, std::vector<int>> Add::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {add(a, b, stream()), to_ax};
+  return {{add(a, b, stream())}, {to_ax}};
 }
 
 bool Arange::is_equivalent(const Primitive& other) const {
@@ -129,12 +130,12 @@ bool Arange::is_equivalent(const Primitive& other) const {
 
 std::vector<array> ArcCos::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  return {jvp(primals, {cotan}, argnums)};
+  return jvp(primals, cotangents, argnums);
 }
 
-array ArcCos::jvp(
+std::vector<array> ArcCos::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -143,25 +144,25 @@
   array one = array(1., primals[0].dtype());
   array t = subtract(one, square(primals[0], stream()), stream());
   array denom = negative(rsqrt(t, stream()), stream());
-  return multiply(tangents[0], denom, stream());
+  return {multiply(tangents[0], denom, stream())};
 }
 
-std::pair<array, int> ArcCos::vmap(
+std::pair<std::vector<array>, std::vector<int>> ArcCos::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {arccos(inputs[0], stream()), axes[0]};
+  return {{arccos(inputs[0], stream())}, axes};
 }
 
 std::vector<array> ArcCosh::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  return {jvp(primals, {cotan}, argnums)};
+  return jvp(primals, cotangents, argnums);
 }
 
-array ArcCosh::jvp(
+std::vector<array> ArcCosh::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -169,25 +170,25 @@
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
   array one = array(1., primals[0].dtype());
   array t = subtract(square(primals[0], stream()), one, stream());
-  return multiply(tangents[0], rsqrt(t, stream()), stream());
+  return {multiply(tangents[0], rsqrt(t, stream()), stream())};
 }
 
-std::pair<array, int> ArcCosh::vmap(
+std::pair<std::vector<array>, std::vector<int>> ArcCosh::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {arccosh(inputs[0], stream()), axes[0]};
+  return {{arccosh(inputs[0], stream())}, axes};
 }
 
 std::vector<array> ArcSin::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  return {jvp(primals, {cotan}, argnums)};
+  return jvp(primals, cotangents, argnums);
 }
 
-array ArcSin::jvp(
+std::vector<array> ArcSin::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -195,25 +196,25 @@
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
   array one = array(1., primals[0].dtype());
   array t = subtract(one, square(primals[0], stream()), stream());
-  return multiply(tangents[0], rsqrt(t, stream()), stream());
+  return {multiply(tangents[0], rsqrt(t, stream()), stream())};
 }
 
-std::pair<array, int> ArcSin::vmap(
+std::pair<std::vector<array>, std::vector<int>> ArcSin::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {arcsin(inputs[0], stream()), axes[0]};
+  return {{arcsin(inputs[0], stream())}, axes};
 }
 
 std::vector<array> ArcSinh::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  return {jvp(primals, {cotan}, argnums)};
+  return jvp(primals, cotangents, argnums);
 }
 
-array ArcSinh::jvp(
+std::vector<array> ArcSinh::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -221,25 +222,25 @@
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
   array one = array(1., primals[0].dtype());
   array t = add(square(primals[0], stream()), one, stream());
-  return multiply(tangents[0], rsqrt(t, stream()), stream());
+  return {multiply(tangents[0], rsqrt(t, stream()), stream())};
 }
 
-std::pair<array, int> ArcSinh::vmap(
+std::pair<std::vector<array>, std::vector<int>> ArcSinh::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {arcsinh(inputs[0], stream()), axes[0]};
+  return {{arcsinh(inputs[0], stream())}, axes};
 }
 
 std::vector<array> ArcTan::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  return {jvp(primals, {cotan}, argnums)};
+  return jvp(primals, cotangents, argnums);
 }
 
-array ArcTan::jvp(
+std::vector<array> ArcTan::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -247,25 +248,25 @@
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
   array one = array(1., primals[0].dtype());
   array t = add(one, square(primals[0], stream()), stream());
-  return divide(tangents[0], t, stream());
+  return {divide(tangents[0], t, stream())};
 }
 
-std::pair<array, int> ArcTan::vmap(
+std::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {arctan(inputs[0], stream()), axes[0]};
+  return {{arctan(inputs[0], stream())}, axes};
 }
 
 std::vector<array> ArcTanh::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  return {jvp(primals, {cotan}, argnums)};
+  return jvp(primals, cotangents, argnums);
 }
 
-array ArcTanh::jvp(
+std::vector<array> ArcTanh::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -273,25 +274,25 @@
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
   array one = array(1., primals[0].dtype());
   array t = subtract(one, square(primals[0], stream()), stream());
-  return divide(tangents[0], t, stream());
+  return {divide(tangents[0], t, stream())};
 }
 
-std::pair<array, int> ArcTanh::vmap(
+std::pair<std::vector<array>, std::vector<int>> ArcTanh::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {arctanh(inputs[0], stream()), axes[0]};
+  return {{arctanh(inputs[0], stream())}, axes};
 }
 
-std::pair<array, int> ArgPartition::vmap(
+std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
 
   return {
-      argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream()), axes[0]};
+      {argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
 }
 
 bool ArgPartition::is_equivalent(const Primitive& other) const {
@@ -304,13 +305,13 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
   return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;
 }
 
-std::pair<array, int> ArgSort::vmap(
+std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {argsort(inputs[0], axis_ + (axes[0] <= axis_), stream()), axes[0]};
+  return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
 }
 
 bool ArgSort::is_equivalent(const Primitive& other) const {
@@ -320,26 +321,26 @@ bool ArgSort::is_equivalent(const Primitive& other) const {
 
 std::vector<array> AsType::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  if (cotan.dtype() != dtype_) {
+  if (cotangents[0].dtype() != dtype_) {
     throw std::invalid_argument(
-        "[astype] Type of cotangent does not much primal output type.");
+        "[astype] Type of cotangent does not match primal output type.");
   }
-  return {astype(cotan, primals[0].dtype(), stream())};
+  return {astype(cotangents[0], primals[0].dtype(), stream())};
 }
 
-array AsType::jvp(
+std::vector<array> AsType::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  return astype(tangents[0], dtype_, stream());
+  return {astype(tangents[0], dtype_, stream())};
 }
 
-std::pair<array, int> AsType::vmap(
+std::pair<std::vector<array>, std::vector<int>> AsType::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {astype(inputs[0], dtype_, stream()), axes[0]};
+  return {{astype(inputs[0], dtype_, stream())}, axes};
 }
 
 bool AsType::is_equivalent(const Primitive& other) const {
@@ -349,13 +350,13 @@ bool AsType::is_equivalent(const Primitive& other) const {
 
 std::vector<array> AsStrided::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
   assert(argnums.size() == 1);
 
   // Extract the sizes and cast them to ints
   int grad_size = primals[0].size();
-  int cotan_size = cotan.size();
+  int cotangents_size = cotangents[0].size();
 
   // Make a flat container to hold the gradients
   auto grad = zeros_like(primals[0], stream());
@@ -364,25 +365,25 @@ std::vector<array> AsStrided::vjp(
 
   // Create the indices that map output to input
   auto idx = arange(grad_size, stream());
   idx = as_strided(idx, shape_, strides_, offset_, stream());
-  idx = reshape(idx, {cotan_size}, stream());
+  idx = reshape(idx, {cotangents_size}, stream());
 
-  // Reshape the cotangent for use with scatter
-  auto flat_cotan = reshape(cotan, {cotan_size, 1}, stream());
+  // Reshape the cotangents for use with scatter
+  auto flat_cotangents = reshape(cotangents[0], {cotangents_size, 1}, stream());
 
   // Finally accumulate the gradients and reshape them to look like the input
-  grad = scatter_add(grad, idx, flat_cotan, 0, stream());
+  grad = scatter_add(grad, idx, flat_cotangents, 0, stream());
   grad = reshape(grad, primals[0].shape(), stream());
 
   return {grad};
 }
 
-array AsStrided::jvp(
+std::vector<array> AsStrided::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   assert(primals.size() == 1);
 
-  return as_strided(tangents[0], shape_, strides_, offset_, stream());
+  return {as_strided(tangents[0], shape_, strides_, offset_, stream())};
 }
 
 bool AsStrided::is_equivalent(const Primitive& other) const {
@@ -393,12 +394,13 @@ bool AsStrided::is_equivalent(const Primitive& other) const {
 
 std::vector<array> Broadcast::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
   assert(argnums.size() == 1);
 
-  // Reduce cotan to the shape of the primal
+  // Reduce cotangents to the shape of the primal
   auto& shape = primals[0].shape();
+  auto& cotan = cotangents[0];
   int diff = cotan.ndim() - shape.size();
   std::vector<int> reduce_axes;
   for (int i = 0; i < cotan.ndim(); ++i) {
@@ -411,15 +413,15 @@ std::vector<array> Broadcast::vjp(
   return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())};
 }
 
-array Broadcast::jvp(
+std::vector<array> Broadcast::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   assert(argnums.size() == 1);
-  return broadcast_to(tangents[0], shape_, stream());
+  return {broadcast_to(tangents[0], shape_, stream())};
 }
 
-std::pair<array, int> Broadcast::vmap(
+std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
@@ -432,7 +434,7 @@ std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
   ax += diff;
   shape_.insert(shape_.begin() + ax, in_shape[ax]);
   auto in = reshape(inputs[0], in_shape, stream());
-  return {broadcast_to(in, shape_, stream()), ax};
+  return {{broadcast_to(in, shape_, stream())}, {ax}};
 }
 
 bool Broadcast::is_equivalent(const Primitive& other) const {
@@ -442,32 +444,33 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
 
 std::vector<array> Ceil::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
-  return {jvp(primals, {cotan}, argnums)};
+  return jvp(primals, cotangents, argnums);
 }
 
-array Ceil::jvp(
+std::vector<array> Ceil::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   assert(primals.size() == 1);
   assert(argnums.size() == 1);
-  return zeros_like(primals[0], stream());
+  return {zeros_like(primals[0], stream())};
 }
 
-std::pair<array, int> Ceil::vmap(
+std::pair<std::vector<array>, std::vector<int>> Ceil::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {ceil(inputs[0], stream()), axes[0]};
+  return {{ceil(inputs[0], stream())}, axes};
 }
 
 std::vector<array> Concatenate::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
+  auto& cotan = cotangents[0];
   std::vector<int> start(cotan.ndim(), 0);
   std::vector<int> stop = cotan.shape();
 
@@ -487,7 +490,7 @@ std::vector<array> Concatenate::vjp(
   return grads;
 }
 
-array Concatenate::jvp(
+std::vector<array> Concatenate::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -505,10 +508,10 @@ std::vector<array> Concatenate::jvp(
       vals.push_back(zeros_like(primals[i], stream()));
     }
   }
-  return concatenate(vals, axis_, stream());
+  return {concatenate(vals, axis_, stream())};
 }
 
-std::pair<array, int> Concatenate::vmap(
+std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   std::vector<array> t_inputs;
@@ -530,7 +533,7 @@ std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
     }
   }
   auto axis = axis_ + (axis_ >= out_ax);
-  return {concatenate(t_inputs, axis, stream()), out_ax};
+  return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
 }
 
 bool Concatenate::is_equivalent(const Primitive& other) const {
@@ -540,7 +543,7 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
 
 std::vector<array> Convolution::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
   assert(primals.size() == 2);
   std::vector<array> grads;
@@ -548,6 +551,7 @@ std::vector<array> Convolution::vjp(
   // Collect info
   auto& in = primals[0];
   auto& wt = primals[1];
+  auto cotan = cotangents[0];
 
   int N = in.shape(0);
   int O = wt.shape(0);
@@ -590,15 +594,15 @@ std::vector<array> Convolution::vjp(
     patches_strides[n_spatial_dim + i] = in_padded_strides[i];
   }
 
-  // Reshape cotan and weights for gemm
-  auto cotan_reshaped = reshape(cotan, {-1, O}, stream());
+  // Reshape cotangents and weights for gemm
+  cotan = reshape(cotangents[0], {-1, O}, stream());
reshape(cotangents[0], {-1, O}, stream()); auto weight_reshaped = reshape(wt, {O, -1}, stream()); for (int a : argnums) { // Grads for input if (a == 0) { - // Gemm with cotan to get patches - auto grad_patches = matmul(cotan_reshaped, weight_reshaped, stream()); + // Gemm with cotangents to get patches + auto grad_patches = matmul(cotan, weight_reshaped, stream()); // Prepare base grad array to accumulate on int in_padded_size = in_padded_strides[0] * in_padded_shape[0]; @@ -634,10 +638,10 @@ std::vector Convolution::vjp( in, padded_axes, padding_, padding_, array(0, in.dtype()), stream()); auto in_patches = as_strided(in_padded, patches_shape, patches_strides, 0, stream()); - in_patches = reshape(in_patches, {cotan_reshaped.shape(0), -1}, stream()); + in_patches = reshape(in_patches, {cotan.shape(0), -1}, stream()); - auto grad = matmul( - transpose(cotan_reshaped, {1, 0}, stream()), in_patches, stream()); + auto grad = + matmul(transpose(cotan, {1, 0}, stream()), in_patches, stream()); grad = reshape(grad, wt.shape(), stream()); grads.push_back(grad); } @@ -656,91 +660,91 @@ bool Convolution::is_equivalent(const Primitive& other) const { std::vector Copy::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return {cotan}; + return cotangents; } -array Copy::jvp( +std::vector Copy::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return tangents[0]; + return tangents; } -std::pair Copy::vmap( +std::pair, std::vector> Copy::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {copy(inputs[0], stream()), axes[0]}; + return {{copy(inputs[0], stream())}, axes}; } std::vector Cos::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return {jvp(primals, cotangents, argnums)}; } -array Cos::jvp( +std::vector Cos::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return multiply( - tangents[0], negative(sin(primals[0], stream()), stream()), stream()); + return {multiply( + tangents[0], negative(sin(primals[0], stream()), stream()), stream())}; } -std::pair Cos::vmap( +std::pair, std::vector> Cos::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {cos(inputs[0], stream()), axes[0]}; + return {{cos(inputs[0], stream())}, axes}; } std::vector Cosh::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Cosh::jvp( +std::vector Cosh::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return multiply(tangents[0], sinh(primals[0], stream()), stream()); + return {multiply(tangents[0], sinh(primals[0], stream()), stream())}; } -std::pair Cosh::vmap( +std::pair, std::vector> Cosh::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {cosh(inputs[0], stream()), axes[0]}; + return {{cosh(inputs[0], stream())}, 
axes};
 }
 
 std::vector<array> Divide::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
   std::vector<array> vjps;
   for (auto arg : argnums) {
     if (arg == 0) {
-      vjps.push_back(divide(cotan, primals[1], stream()));
+      vjps.push_back(divide(cotangents[0], primals[1], stream()));
     } else {
       vjps.push_back(negative(
           divide(
-              multiply(cotan, primals[0], stream()),
+              multiply(cotangents[0], primals[0], stream()),
               square(primals[1], stream()),
               stream()),
           stream()));
@@ -749,7 +753,32 @@ std::vector<array> Divide::vjp(
   return vjps;
 }
 
-array Divide::jvp(
+std::vector<array> DivMod::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums) {
+  std::vector<array> vjps;
+  for (auto arg : argnums) {
+    vjps.push_back(zeros_like(primals[arg], stream()));
+  }
+  return vjps;
+}
+
+std::vector<array> DivMod::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  // One tangent per output: both the quotient and the remainder are
+  // piecewise constant in their inputs.
+  return {
+      zeros_like(primals[0], stream()), zeros_like(primals[0], stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> DivMod::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
+  // One mapped axis per output.
+  return {divmod(a, b, stream()), {to_ax, to_ax}};
+}
+
+std::vector<array> Divide::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -770,34 +799,35 @@
   if (argnums.size() > 1) {
     out = add(out, jvp_fun(1), stream());
   }
-  return out;
+  return {out};
 }
 
-std::pair<array, int> Divide::vmap(
+std::pair<std::vector<array>, std::vector<int>> Divide::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {divide(a, b, stream()), to_ax};
+  return {{divide(a, b, stream())}, {to_ax}};
 }
 
 std::vector<array> Remainder::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
   std::vector<array> vjps;
   for (auto arg : argnums) {
     if (arg == 0) {
-      vjps.push_back(cotan);
+      vjps.push_back(cotangents[0]);
     } else {
       auto x_over_y = divide(primals[0], primals[1], stream());
       x_over_y = floor(x_over_y, stream());
-      vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream()));
+      vjps.push_back(
+          negative(multiply(x_over_y, cotangents[0], stream()), stream()));
     }
   }
   return vjps;
 }
 
-array Remainder::jvp(
+std::vector<array> Remainder::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
@@ -815,26 +845,26 @@
   if (argnums.size() > 1) {
     out = add(out, jvp_fun(1), stream());
   }
-  return out;
+  return {out};
 }
 
-std::pair<array, int> Remainder::vmap(
+std::pair<std::vector<array>, std::vector<int>> Remainder::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {remainder(a, b, stream()), to_ax};
+  return {{remainder(a, b, stream())}, {to_ax}};
 }
 
-std::pair<array, int> Equal::vmap(
+std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {equal(a, b, stream()), axes[0]};
+  return {{equal(a, b, stream())}, axes};
 }
 
 std::vector<array> Equal::vjp(
     const std::vector<array>& primals,
-    const array& cotan,
+    const std::vector<array>& cotangents,
     const std::vector<int>& argnums) {
   std::vector<array> vjps;
   for (auto arg : argnums) {
@@ -843,22 +873,22 @@ std::vector<array> Equal::vjp(
   return vjps;
 }
 
-array Equal::jvp(
+std::vector<array> Equal::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape());
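
Aside on the DivMod rules above: a minimal usage sketch, not part of the patch, assuming the public mlx::core API (the mlx/mlx.h umbrella header) and the divmod op this diff introduces. The fused call builds a single DivMod node whose two sibling outputs materialize together, while the unfused pair builds two independent single-output nodes over the same inputs:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto a = array({7.0f, -7.0f, 9.0f});
  auto b = array({2.0f, 2.0f, 4.0f});

  // One primitive, two sibling outputs: {quotient, remainder}.
  auto qr = divmod(a, b);
  eval(qr);

  // Unfused equivalent: two primitives that each redo the division.
  auto q = floor_divide(a, b);
  auto r = remainder(a, b);
  eval({q, r});
  return 0;
}

Since both outputs are piecewise constant in their inputs, the zero-valued vjp/jvp above mirror what Floor and Ceil already do.
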
- return zeros(shape, bool_, stream()); + return {zeros(shape, bool_, stream())}; } std::vector Erf::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Erf::jvp( +std::vector Erf::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -866,28 +896,28 @@ array Erf::jvp( assert(argnums.size() == 1); auto dtype = primals[0].dtype(); auto scale = multiply(array(M_2_SQRTPI, dtype), tangents[0], stream()); - return multiply( + return {multiply( scale, exp(negative(square(primals[0], stream()), stream()), stream()), - stream()); + stream())}; } -std::pair Erf::vmap( +std::pair, std::vector> Erf::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {erf(inputs[0], stream()), axes[0]}; + return {{erf(inputs[0], stream())}, axes}; } std::vector ErfInv::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array ErfInv::jvp( +std::vector ErfInv::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -895,42 +925,42 @@ array ErfInv::jvp( assert(argnums.size() == 1); auto dtype = primals[0].dtype(); auto scale = multiply(array(1.0 / M_2_SQRTPI, dtype), tangents[0], stream()); - return multiply( + return {multiply( scale, exp(square(erfinv(primals[0], stream()), stream()), stream()), - stream()); + stream())}; } -std::pair ErfInv::vmap( +std::pair, std::vector> ErfInv::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {erfinv(inputs[0], stream()), axes[0]}; + return {{erfinv(inputs[0], stream())}, axes}; } std::vector Exp::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Exp::jvp( +std::vector Exp::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return multiply(tangents[0], exp(primals[0], stream()), stream()); + return {multiply(tangents[0], exp(primals[0], stream()), stream())}; } -std::pair Exp::vmap( +std::pair, std::vector> Exp::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {exp(inputs[0], stream()), axes[0]}; + return {{exp(inputs[0], stream())}, axes}; } bool FFT::is_equivalent(const Primitive& other) const { @@ -939,7 +969,7 @@ bool FFT::is_equivalent(const Primitive& other) const { real_ == r_other.real_; } -std::pair FFT::vmap( +std::pair, std::vector> FFT::vmap( const std::vector& inputs, const std::vector& axes) { auto& in = inputs[0]; @@ -956,24 +986,24 @@ std::pair FFT::vmap( } } return { - array( + {array( out_shape, real_ && inverse_ ? 
float32 : complex64, std::make_unique(stream(), fft_axes, inverse_, real_), - {in}), - ax}; + {in})}, + {ax}}; } std::vector FFT::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); auto& in = primals[0]; std::vector axes(axes_.begin(), axes_.end()); if (real_ && inverse_) { - auto out = fft::fftn(cotan, axes, stream()); + auto out = fft::fftn(cotangents[0], axes, stream()); auto start = std::vector(out.ndim(), 0); auto stop = in.shape(); out = slice(out, start, stop, stream()); @@ -990,15 +1020,16 @@ std::vector FFT::vjp( for (auto ax : axes_) { n.push_back(in.shape()[ax]); } - return {astype(fft::fftn(cotan, n, axes, stream()), in.dtype(), stream())}; + return {astype( + fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())}; } else if (inverse_) { - return {fft::ifftn(cotan, axes, stream())}; + return {fft::ifftn(cotangents[0], axes, stream())}; } else { - return {fft::fftn(cotan, axes, stream())}; + return {fft::fftn(cotangents[0], axes, stream())}; } } -array FFT::jvp( +std::vector FFT::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1006,59 +1037,59 @@ array FFT::jvp( assert(argnums.size() == 1); auto& tan = tangents[0]; if (real_ & inverse_) { - return fft::irfftn(tan, stream()); + return {fft::irfftn(tan, stream())}; } else if (real_) { - return fft::rfftn(tan, stream()); + return {fft::rfftn(tan, stream())}; } else if (inverse_) { - return fft::ifftn(tan, stream()); + return {fft::ifftn(tan, stream())}; } else { - return fft::fftn(tan, stream()); + return {fft::fftn(tan, stream())}; } } std::vector Floor::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Floor::jvp( +std::vector Floor::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return zeros_like(primals[0], stream()); + return {zeros_like(primals[0], stream())}; } -std::pair Floor::vmap( +std::pair, std::vector> Floor::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {floor(inputs[0], stream()), axes[0]}; + return {{floor(inputs[0], stream())}, axes}; } std::vector Full::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return {multiply(cotan, primals[0], stream())}; + return {multiply(cotangents[0], primals[0], stream())}; } -array Full::jvp( +std::vector Full::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return tangents[0]; + return tangents; } -std::pair Full::vmap( +std::pair, std::vector> Full::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); @@ -1066,10 +1097,10 @@ std::pair Full::vmap( auto& in = inputs[0]; auto out = array(in.shape(), in.dtype(), std::make_unique(stream()), {in}); - return {out, axes[0]}; + return {{out}, axes}; } -std::pair Gather::vmap( +std::pair, std::vector> Gather::vmap( const std::vector& inputs, const std::vector& axes) { auto& src = inputs[0]; @@ -1118,12 +1149,12 @@ std::pair Gather::vmap( out_ax = 
max_dims + axes[0]; } } - return {gather(src, indices, gather_axes, slice_sizes, stream()), out_ax}; + return {{gather(src, indices, gather_axes, slice_sizes, stream())}, {out_ax}}; } std::vector Gather::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { if (argnums.size() > 1 || argnums[0] != 0) { throw std::invalid_argument( @@ -1131,10 +1162,10 @@ std::vector Gather::vjp( } auto src = zeros_like(primals[0], stream()); std::vector inds(primals.begin() + 1, primals.end()); - return {scatter_add(src, inds, cotan, axes_, stream())}; + return {scatter_add(src, inds, cotangents[0], axes_, stream())}; } -array Gather::jvp( +std::vector Gather::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1143,7 +1174,7 @@ array Gather::jvp( "[gather] Cannot calculate JVP with respect to indices."); } std::vector inds(primals.begin() + 1, primals.end()); - return gather(tangents[0], inds, axes_, slice_sizes_, stream()); + return {gather(tangents[0], inds, axes_, slice_sizes_, stream())}; } bool Gather::is_equivalent(const Primitive& other) const { @@ -1151,16 +1182,16 @@ bool Gather::is_equivalent(const Primitive& other) const { return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_; } -std::pair Greater::vmap( +std::pair, std::vector> Greater::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {greater(a, b, stream()), axes[0]}; + return {{greater(a, b, stream())}, axes}; } std::vector Greater::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; for (auto arg : argnums) { @@ -1169,24 +1200,24 @@ std::vector Greater::vjp( return vjps; } -array Greater::jvp( +std::vector Greater::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); - return zeros(shape, bool_, stream()); + return {zeros(shape, bool_, stream())}; } -std::pair GreaterEqual::vmap( +std::pair, std::vector> GreaterEqual::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {greater_equal(a, b, stream()), axes[0]}; + return {{greater_equal(a, b, stream())}, axes}; } std::vector GreaterEqual::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; for (auto arg : argnums) { @@ -1195,24 +1226,24 @@ std::vector GreaterEqual::vjp( return vjps; } -array GreaterEqual::jvp( +std::vector GreaterEqual::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); - return zeros(shape, bool_, stream()); + return {zeros(shape, bool_, stream())}; } -std::pair Less::vmap( +std::pair, std::vector> Less::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {less(a, b, stream()), axes[0]}; + return {{less(a, b, stream())}, axes}; } std::vector Less::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; for (auto arg : argnums) { @@ -1221,24 +1252,24 @@ std::vector Less::vjp( return vjps; } -array Less::jvp( +std::vector Less::jvp( const std::vector& primals, 
const std::vector& tangents, const std::vector& argnums) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); - return zeros(shape, bool_, stream()); + return {zeros(shape, bool_, stream())}; } -std::pair LessEqual::vmap( +std::pair, std::vector> LessEqual::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {less_equal(a, b, stream()), axes[0]}; + return {{less_equal(a, b, stream())}, axes}; } std::vector LessEqual::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; for (auto arg : argnums) { @@ -1247,22 +1278,22 @@ std::vector LessEqual::vjp( return vjps; } -array LessEqual::jvp( +std::vector LessEqual::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); - return zeros(shape, bool_, stream()); + return {zeros(shape, bool_, stream())}; } std::vector Log::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Log::jvp( +std::vector Log::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1273,132 +1304,140 @@ array Log::jvp( auto scale = 1 / std::log(base_ == Base::ten ? 10.0f : 2.0f); out = multiply(array(scale, out.dtype()), out, stream()); } - return out; + return {out}; } -std::pair Log::vmap( +std::pair, std::vector> Log::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); auto& in = inputs[0]; return { - array( - in.shape(), in.dtype(), std::make_unique(stream(), base_), {in}), - axes[0]}; + {array( + in.shape(), + in.dtype(), + std::make_unique(stream(), base_), + {in})}, + axes}; } std::vector Log1p::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Log1p::jvp( +std::vector Log1p::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); auto dtype = primals[0].dtype(); - return divide( - tangents[0], add(array(1.0f, dtype), primals[0], stream()), stream()); + return {divide( + tangents[0], add(array(1.0f, dtype), primals[0], stream()), stream())}; } -std::pair Log1p::vmap( +std::pair, std::vector> Log1p::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {log1p(inputs[0], stream()), axes[0]}; + return {{log1p(inputs[0], stream())}, axes}; } std::vector LogicalNot::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array LogicalNot::jvp( +std::vector LogicalNot::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return zeros_like(tangents[0], stream()); + return {zeros_like(tangents[0], stream())}; } -std::pair LogicalNot::vmap( +std::pair, std::vector> LogicalNot::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); 
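
The comparison and boolean primitives in this stretch (Less, LessEqual, and their siblings) all define jvp as zeros of the broadcast shape: a boolean output is flat almost everywhere in its inputs, so any pushed-forward tangent vanishes. A small sketch of the consequence through the transform API, assuming the single-argument overload of mlx::core::jvp:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // f(x) = [x < 0] is piecewise constant, so its tangent is
  // identically zero (cf. Less::jvp above).
  auto fn = [](const array& x) { return less(x, array(0.0f)); };
  auto [out, tangent] = jvp(fn, array(1.0f), array(1.0f));
  eval({out, tangent});  // out is false, tangent is zero
  return 0;
}
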
assert(axes.size() == 1); - return {logical_not(inputs[0], stream()), axes[0]}; + return {{logical_not(inputs[0], stream())}, axes}; } std::vector LogicalAnd::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(primals.size() == 2); - - return {zeros_like(cotan, stream()), zeros_like(cotan, stream())}; + std::vector vjps = {zeros_like(cotangents[0], stream())}; + if (argnums.size() > 1) { + vjps.push_back(vjps.back()); + } + return vjps; } -array LogicalAnd::jvp( +std::vector LogicalAnd::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() <= 2); - - return zeros_like(primals[0], stream()); + return {zeros_like(primals[0], stream())}; } -std::pair LogicalAnd::vmap( +std::pair, std::vector> LogicalAnd::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 2); assert(axes.size() == 2); auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {logical_and(a, b, stream()), to_ax}; + return {{logical_and(a, b, stream())}, {to_ax}}; } std::vector LogicalOr::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(primals.size() == 2); - - return {zeros_like(cotan, stream()), zeros_like(cotan, stream())}; + std::vector vjps = {zeros_like(cotangents[0], stream())}; + if (argnums.size() > 1) { + vjps.push_back(vjps.back()); + } + return vjps; } -array LogicalOr::jvp( +std::vector LogicalOr::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() <= 2); - return zeros_like(primals[0], stream()); + return {zeros_like(primals[0], stream())}; } -std::pair LogicalOr::vmap( +std::pair, std::vector> LogicalOr::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 2); assert(axes.size() == 2); auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {logical_or(a, b, stream()), to_ax}; + return {{logical_or(a, b, stream())}, {to_ax}}; } std::vector LogAddExp::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { auto a = primals[0]; auto b = primals[1]; @@ -1406,14 +1445,14 @@ std::vector LogAddExp::vjp( std::vector vjps; for (auto arg : argnums) { vjps.push_back(multiply( - cotan, + cotangents[0], arg == 0 ? 
s : subtract(array(1.0f, s.dtype()), s, stream()), stream())); } return vjps; } -array LogAddExp::jvp( +std::vector LogAddExp::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1431,21 +1470,22 @@ array LogAddExp::jvp( if (argnums.size() > 1) { out = add(out, jvp_fun(1), stream()); } - return out; + return {out}; } -std::pair LogAddExp::vmap( +std::pair, std::vector> LogAddExp::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {logaddexp(a, b, stream()), to_ax}; + return {{logaddexp(a, b, stream())}, {to_ax}}; } std::vector Matmul::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; + auto& cotan = cotangents[0]; std::vector reorder(cotan.ndim()); std::iota(reorder.begin(), reorder.end(), 0); std::iter_swap(reorder.end() - 1, reorder.end() - 2); @@ -1463,15 +1503,9 @@ std::vector Matmul::vjp( return vjps; } -std::pair Matmul::vmap( - const std::vector& inputs, - const std::vector& axes) { - return {array(1.0), 0}; -} - std::vector Maximum::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { auto& a = primals[0]; auto& b = primals[1]; @@ -1479,12 +1513,12 @@ std::vector Maximum::vjp( for (auto arg : argnums) { auto mask = (arg == 0) ? greater(a, b, stream()) : less_equal(a, b, stream()); - vjps.push_back(multiply(cotan, mask, stream())); + vjps.push_back(multiply(cotangents[0], mask, stream())); } - return vjps; + return {vjps}; } -array Maximum::jvp( +std::vector Maximum::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1500,19 +1534,19 @@ array Maximum::jvp( if (argnums.size() > 1) { out = add(out, jvp_fun(1), stream()); } - return out; + return {out}; } -std::pair Maximum::vmap( +std::pair, std::vector> Maximum::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {maximum(a, b, stream()), to_ax}; + return {{maximum(a, b, stream())}, {to_ax}}; } std::vector Minimum::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { auto& a = primals[0]; auto& b = primals[1]; @@ -1520,12 +1554,12 @@ std::vector Minimum::vjp( for (auto arg : argnums) { auto mask = (arg == 0) ? 
less(a, b, stream()) : greater_equal(a, b, stream()); - vjps.push_back(multiply(cotan, mask, stream())); + vjps.push_back(multiply(cotangents[0], mask, stream())); } return vjps; } -array Minimum::jvp( +std::vector Minimum::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1541,17 +1575,17 @@ array Minimum::jvp( if (argnums.size() > 1) { out = add(out, jvp_fun(1), stream()); } - return out; + return {out}; } -std::pair Minimum::vmap( +std::pair, std::vector> Minimum::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {minimum(a, b, stream()), to_ax}; + return {{minimum(a, b, stream())}, {to_ax}}; } -array Multiply::jvp( +std::vector Multiply::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1561,61 +1595,61 @@ array Multiply::jvp( arg = argnums[1]; jvp = add(jvp, multiply(tangents[1], primals[1 - arg], stream()), stream()); } - return jvp; + return {jvp}; } std::vector Multiply::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; for (auto arg : argnums) { - vjps.push_back(multiply(primals[1 - arg], cotan, stream())); + vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream())); } return vjps; } -std::pair Multiply::vmap( +std::pair, std::vector> Multiply::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {multiply(a, b, stream()), to_ax}; + return {{multiply(a, b, stream())}, {to_ax}}; } std::vector Negative::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Negative::jvp( +std::vector Negative::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return negative(tangents[0], stream()); + return {negative(tangents[0], stream())}; } -std::pair Negative::vmap( +std::pair, std::vector> Negative::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {negative(inputs[0], stream()), axes[0]}; + return {{negative(inputs[0], stream())}, axes}; } -std::pair NotEqual::vmap( +std::pair, std::vector> NotEqual::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {not_equal(a, b, stream()), axes[0]}; + return {{not_equal(a, b, stream())}, axes}; } std::vector NotEqual::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; for (auto arg : argnums) { @@ -1624,20 +1658,21 @@ std::vector NotEqual::vjp( return vjps; } -array NotEqual::jvp( +std::vector NotEqual::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); - return zeros(shape, bool_, stream()); + return {zeros(shape, bool_, stream())}; } std::vector Pad::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(argnums.size() == 1 && argnums[0] == 0); + auto& cotan = cotangents[0]; std::vector start(cotan.ndim(), 0); std::vector stop = cotan.shape(); @@ 
-1651,22 +1686,22 @@ std::vector Pad::vjp( return {out}; } -array Pad::jvp( +std::vector Pad::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(argnums.size() == 1 && argnums[0] == 0); - return pad( - tangents[0], - axes_, - low_pad_size_, - high_pad_size_, - array(0, tangents[0].dtype()), - stream()); + return { + pad(tangents[0], + axes_, + low_pad_size_, + high_pad_size_, + array(0, tangents[0].dtype()), + stream())}; } -std::pair Pad::vmap( +std::pair, std::vector> Pad::vmap( const std::vector& inputs, const std::vector& axes) { throw std::runtime_error("Pad vmap is NYI."); @@ -1681,12 +1716,12 @@ bool Pad::is_equivalent(const Primitive& other) const { std::vector Partition::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Partition::jvp( +std::vector Partition::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1694,16 +1729,16 @@ array Partition::jvp( assert(tangents.size() == 1); auto sort_idx = argpartition(primals[0], kth_, axis_, stream()); auto out = take_along_axis(tangents[0], sort_idx, axis_, stream()); - return out; + return {out}; } -std::pair Partition::vmap( +std::pair, std::vector> Partition::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {partition(inputs[0], axis_ + (axes[0] <= axis_), stream()), axes[0]}; + return {{partition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; } bool Partition::is_equivalent(const Primitive& other) const { @@ -1713,7 +1748,7 @@ bool Partition::is_equivalent(const Primitive& other) const { std::vector Power::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; for (auto arg : argnums) { @@ -1731,30 +1766,31 @@ std::vector Power::vjp( power(primals[0], primals[1], stream()), stream())); } - vjps.back() = multiply(cotan, vjps.back(), stream()); + vjps.back() = multiply(cotangents[0], vjps.back(), stream()); } return vjps; } -array Power::jvp( +std::vector Power::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { - auto jvp = vjp(primals, tangents[0], {argnums[0]})[0]; + auto jvp = vjp(primals, {tangents[0]}, {argnums[0]}); if (argnums.size() > 1) { - jvp = add(jvp, vjp(primals, tangents[1], {argnums[1]})[0], stream()); + jvp[0] = + add(jvp[0], vjp(primals, {tangents[1]}, {argnums[1]})[0], stream()); } return jvp; } -std::pair Power::vmap( +std::pair, std::vector> Power::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {power(a, b, stream()), to_ax}; + return {{power(a, b, stream())}, {to_ax}}; } -std::pair QuantizedMatmul::vmap( +std::pair, std::vector> QuantizedMatmul::vmap( const std::vector& inputs, const std::vector& axes) { throw std::runtime_error("QuantizedMatmul::vmap NYI"); @@ -1762,7 +1798,7 @@ std::pair QuantizedMatmul::vmap( std::vector QuantizedMatmul::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; @@ -1771,7 +1807,7 @@ std::vector QuantizedMatmul::vjp( // gradient wrt to x if (arg == 0) { vjps.push_back(quantized_matmul( - cotan, + cotangents[0], primals[1], primals[2], primals[3], @@ -1790,7 +1826,7 
@@ std::vector QuantizedMatmul::vjp( return vjps; } -array QuantizedMatmul::jvp( +std::vector QuantizedMatmul::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -1802,7 +1838,7 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const { return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_; } -std::pair RandomBits::vmap( +std::pair, std::vector> RandomBits::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); @@ -1838,7 +1874,7 @@ std::pair RandomBits::vmap( get_dtype(), std::make_unique(stream(), shape, width_), {key}); - return {out, kax}; + return {{out}, {kax}}; } bool RandomBits::is_equivalent(const Primitive& other) const { @@ -1846,7 +1882,7 @@ bool RandomBits::is_equivalent(const Primitive& other) const { return shape_ == r_other.shape_; } -std::pair Reshape::vmap( +std::pair, std::vector> Reshape::vmap( const std::vector& inputs, const std::vector& axes) { // Transpose the input so that the vmap dim is first. @@ -1860,27 +1896,27 @@ std::pair Reshape::vmap( auto out = transpose(in, reorder, stream()); shape_.insert(shape_.begin(), in.shape()[ax]); // Reshape the transposed input to the new shape. - return {reshape(out, shape_, stream()), 0}; + return {{reshape(out, shape_, stream())}, {0}}; } std::vector Reshape::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); assert(argnums[0] == 0); - return {reshape(cotan, primals[0].shape(), stream())}; + return {reshape(cotangents[0], primals[0].shape(), stream())}; } -array Reshape::jvp( +std::vector Reshape::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); assert(argnums[0] == 0); - return reshape(tangents[0], shape_, stream()); + return {reshape(tangents[0], shape_, stream())}; } bool Reshape::is_equivalent(const Primitive& other) const { @@ -1890,7 +1926,7 @@ bool Reshape::is_equivalent(const Primitive& other) const { std::vector Reduce::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { auto in = primals[0]; @@ -1898,7 +1934,7 @@ std::vector Reduce::vjp( for (auto ax : axes_) { shape[ax] = 1; } - + auto& cotan = cotangents[0]; if (reduce_type_ == Reduce::Sum) { return { broadcast_to(reshape(cotan, shape, stream()), in.shape(), stream())}; @@ -1982,11 +2018,10 @@ std::vector Reduce::vjp( } } -std::pair Reduce::vmap( +std::pair, std::vector> Reduce::vmap( const std::vector& inputs, const std::vector& axes) { - // TODO implement - return {array(1.0f), axes[0]}; + throw std::runtime_error("Reduce::vmap not yet implemented."); } bool Reduce::is_equivalent(const Primitive& other) const { @@ -1996,63 +2031,61 @@ bool Reduce::is_equivalent(const Primitive& other) const { std::vector Round::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Round::jvp( +std::vector Round::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return zeros_like(primals[0], stream()); + return {zeros_like(primals[0], stream())}; } -std::pair Round::vmap( +std::pair, std::vector> Round::vmap( const std::vector& 
inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {round(inputs[0], stream()), axes[0]}; + return {{round(inputs[0], stream())}, axes}; } -std::pair Scan::vmap( +std::pair, std::vector> Scan::vmap( const std::vector& inputs, const std::vector& axes) { auto& in = inputs[0]; - auto axis = axes[0]; - auto out_dtype = (in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype(); return { - array( + {array( in.shape(), out_dtype, std::make_unique( stream(), reduce_type_, - axis_ + (axis <= axis_), + axis_ + (axes[0] <= axis_), reverse_, inclusive_), - {in}), - axis}; + {in})}, + axes}; } std::vector Scan::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums[0] == 0); if (reduce_type_ == Scan::Sum) { - return {cumsum(cotan, axis_, !reverse_, inclusive_, stream())}; + return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())}; } else if (reduce_type_ == Scan::Prod) { // TODO: Make it numerically stable when we introduce where() auto prod = cumprod(primals[0], axis_, reverse_, inclusive_, stream()); - auto partial_grads = multiply(prod, cotan, stream()); + auto partial_grads = multiply(prod, cotangents[0], stream()); auto accum_grads = cumsum(partial_grads, axis_, !reverse_, inclusive_, stream()); return {divide(accum_grads, primals[0], stream())}; @@ -2062,7 +2095,7 @@ std::vector Scan::vjp( } } -array Scan::jvp( +std::vector Scan::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -2070,7 +2103,7 @@ array Scan::jvp( assert(argnums[0] == 0); if (reduce_type_ == Scan::Sum) { - return cumsum(tangents[0], axis_, reverse_, inclusive_, stream()); + return {cumsum(tangents[0], axis_, reverse_, inclusive_, stream())}; } else { throw std::runtime_error( "JVP is not implemented for cumulative prod/min/max"); @@ -2091,12 +2124,12 @@ bool Scatter::is_equivalent(const Primitive& other) const { std::vector Sigmoid::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Sigmoid::jvp( +std::vector Sigmoid::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -2105,90 +2138,90 @@ array Sigmoid::jvp( auto s = sigmoid(primals[0], stream()); auto sprime = multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream()); - return multiply(tangents[0], sprime, stream()); + return {multiply(tangents[0], sprime, stream())}; } -std::pair Sigmoid::vmap( +std::pair, std::vector> Sigmoid::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {sigmoid(inputs[0], stream()), axes[0]}; + return {{sigmoid(inputs[0], stream())}, axes}; } std::vector Sign::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Sign::jvp( +std::vector Sign::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return zeros(primals[0].shape(), primals[0].dtype(), stream()); + return {zeros(primals[0].shape(), primals[0].dtype(), stream())}; } -std::pair Sign::vmap( +std::pair, std::vector> Sign::vmap( const 
std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {sign(inputs[0], stream()), axes[0]}; + return {{sign(inputs[0], stream())}, axes}; } std::vector Sin::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Sin::jvp( +std::vector Sin::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return multiply(tangents[0], cos(primals[0], stream()), stream()); + return {multiply(tangents[0], cos(primals[0], stream()), stream())}; } -std::pair Sin::vmap( +std::pair, std::vector> Sin::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {sin(inputs[0], stream()), axes[0]}; + return {{sin(inputs[0], stream())}, axes}; } std::vector Sinh::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Sinh::jvp( +std::vector Sinh::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); - return multiply(tangents[0], cosh(primals[0], stream()), stream()); + return {multiply(tangents[0], cosh(primals[0], stream()), stream())}; } -std::pair Sinh::vmap( +std::pair, std::vector> Sinh::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {sinh(inputs[0], stream()), axes[0]}; + return {{sinh(inputs[0], stream())}, axes}; } -std::pair Slice::vmap( +std::pair, std::vector> Slice::vmap( const std::vector& inputs, const std::vector& axes) { auto start = start_indices_; @@ -2199,12 +2232,12 @@ std::pair Slice::vmap( start.insert(start.begin() + ax, 0); stop.insert(stop.begin() + ax, input.shape(ax)); strides.insert(strides.begin() + ax, 1); - return {slice(input, start, stop, strides, stream()), ax}; + return {{slice(input, start, stop, strides, stream())}, {ax}}; } std::vector Slice::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { // Check inputs assert(primals.size() == 1); @@ -2229,8 +2262,8 @@ std::vector Slice::vjp( } } - // Transpose and reshape cotan - auto cotan_ = cotan; + // Transpose and reshape cotangents + auto cotan = cotangents[0]; if (!ind_axes.empty()) { std::vector cotan_shape; for (auto ax : ind_axes) { @@ -2246,8 +2279,8 @@ std::vector Slice::vjp( cotan_axes.push_back(i); } } - cotan_ = - reshape(transpose(cotan_, cotan_axes, stream()), cotan_shape, stream()); + cotan = + reshape(transpose(cotan, cotan_axes, stream()), cotan_shape, stream()); } // Make indices broadcastable @@ -2264,16 +2297,16 @@ std::vector Slice::vjp( ind_axes.end(), single_ind_axes.begin(), single_ind_axes.end()); return {scatter_add( - zeros_like(primals[0], stream()), inds, cotan_, ind_axes, stream())}; + zeros_like(primals[0], stream()), inds, cotan, ind_axes, stream())}; } -array Slice::jvp( +std::vector Slice::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { // Check inputs assert(primals.size() == 1); - return slice(tangents[0], start_indices_, end_indices_, strides_, stream()); + return {slice(tangents[0], 
start_indices_, end_indices_, strides_, stream())}; } bool Slice::is_equivalent(const Primitive& other) const { @@ -2283,7 +2316,7 @@ bool Slice::is_equivalent(const Primitive& other) const { end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_); } -std::pair Softmax::vmap( +std::pair, std::vector> Softmax::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); @@ -2298,17 +2331,17 @@ std::pair Softmax::vmap( } else { softmax_axes.push_back(-2); } - return {softmax(inputs[0], softmax_axes, stream()), axes[0]}; + return {{softmax(inputs[0], softmax_axes, stream())}, axes}; } std::vector Softmax::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Softmax::jvp( +std::vector Softmax::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -2316,27 +2349,28 @@ array Softmax::jvp( assert(tangents.size() == 1); auto s = softmax(primals[0], std::vector{-1}, stream()); auto sv = multiply(s, tangents[0], stream()); - return subtract( - sv, multiply(s, sum(sv, std::vector{-1}, true, stream()), stream())); + return {subtract( + sv, + multiply(s, sum(sv, std::vector{-1}, true, stream()), stream()))}; } -std::pair Sort::vmap( +std::pair, std::vector> Sort::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {sort(inputs[0], axis_ + (axes[0] <= axis_), stream()), axes[0]}; + return {{sort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; } std::vector Sort::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Sort::jvp( +std::vector Sort::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -2344,7 +2378,7 @@ array Sort::jvp( assert(tangents.size() == 1); auto sort_idx = argsort(primals[0], axis_, stream()); auto out = take_along_axis(tangents[0], sort_idx, axis_, stream()); - return out; + return {out}; } bool Sort::is_equivalent(const Primitive& other) const { @@ -2354,39 +2388,39 @@ bool Sort::is_equivalent(const Primitive& other) const { std::vector Square::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Square::jvp( +std::vector Square::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(tangents.size() == 1); - return multiply( + return {multiply( primals[0], multiply(array(2, primals[0].dtype()), tangents[0], stream()), - stream()); + stream())}; } -std::pair Square::vmap( +std::pair, std::vector> Square::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {square(inputs[0], stream()), axes[0]}; + return {{square(inputs[0], stream())}, axes}; } std::vector Sqrt::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Sqrt::jvp( +std::vector Sqrt::jvp( const std::vector& primals, const std::vector& tangents, const 
std::vector& argnums) { @@ -2396,26 +2430,26 @@ array Sqrt::jvp( if (recip_) { auto one_over_x_root_x = divide(rsqrt(primals[0], stream()), primals[0], stream()); - return multiply( + return {multiply( multiply(array(-0.5, dtype), tangents[0], stream()), one_over_x_root_x, - stream()); + stream())}; } - return divide( + return {divide( multiply(array(0.5, dtype), tangents[0], stream()), sqrt(primals[0], stream()), - stream()); + stream())}; } -std::pair Sqrt::vmap( +std::pair, std::vector> Sqrt::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - if (recip_) - return {rsqrt(inputs[0], stream()), axes[0]}; - - return {sqrt(inputs[0], stream()), axes[0]}; + if (recip_) { + return {{rsqrt(inputs[0], stream())}, axes}; + } + return {{sqrt(inputs[0], stream())}, axes}; } bool Sqrt::is_equivalent(const Primitive& other) const { @@ -2423,19 +2457,19 @@ bool Sqrt::is_equivalent(const Primitive& other) const { return recip_ == s_other.recip_; } -std::pair StopGradient::vmap( +std::pair, std::vector> StopGradient::vmap( const std::vector& inputs, const std::vector& axes) { - return {inputs[0], axes[0]}; + return {inputs, axes}; }; std::vector Subtract::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { std::vector vjps; for (auto arg : argnums) { - auto vjp = cotan; + auto vjp = cotangents[0]; if (arg == 1) { vjp = negative(vjp, stream()); } @@ -2444,7 +2478,7 @@ std::vector Subtract::vjp( return vjps; } -array Subtract::jvp( +std::vector Subtract::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { @@ -2456,69 +2490,69 @@ array Subtract::jvp( if (argnums.size() > 1) { out = add(out, jvp_fun(1), stream()); } - return out; + return {out}; } -std::pair Subtract::vmap( +std::pair, std::vector> Subtract::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {subtract(a, b, stream()), to_ax}; + return {{subtract(a, b, stream())}, {to_ax}}; } std::vector Tan::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Tan::jvp( +std::vector Tan::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); array cos_sq = square(cos(primals[0], stream()), stream()); - return divide(tangents[0], cos_sq, stream()); + return {divide(tangents[0], cos_sq, stream())}; } -std::pair Tan::vmap( +std::pair, std::vector> Tan::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {tan(inputs[0], stream()), axes[0]}; + return {{tan(inputs[0], stream())}, axes}; } std::vector Tanh::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { - return {jvp(primals, {cotan}, argnums)}; + return jvp(primals, cotangents, argnums); } -array Tanh::jvp( +std::vector Tanh::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); array cosh_sq = square(cosh(primals[0], stream()), stream()); - return divide(tangents[0], cosh_sq, stream()); + return {divide(tangents[0], cosh_sq, stream())}; } -std::pair Tanh::vmap( +std::pair, 
std::vector> Tanh::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - return {tanh(inputs[0], stream()), axes[0]}; + return {{tanh(inputs[0], stream())}, axes}; } std::vector Transpose::vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); @@ -2526,19 +2560,19 @@ std::vector Transpose::vjp( for (int i = 0; i < axes_.size(); ++i) { iaxes[axes_[i]] = i; } - return {transpose(cotan, iaxes, stream())}; + return {transpose(cotangents[0], iaxes, stream())}; } -array Transpose::jvp( +std::vector Transpose::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 1); assert(tangents.size() == 1); - return transpose(tangents[0], axes_, stream()); + return {transpose(tangents[0], axes_, stream())}; } -std::pair Transpose::vmap( +std::pair, std::vector> Transpose::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); @@ -2550,7 +2584,7 @@ std::pair Transpose::vmap( } } axes_.insert(axes_.begin() + vdim, vdim); - return {transpose(inputs[0], axes_, stream()), vdim}; + return {{transpose(inputs[0], axes_, stream())}, {vdim}}; } bool Transpose::is_equivalent(const Primitive& other) const { diff --git a/mlx/primitives.h b/mlx/primitives.h index c69b830a2..786da59d1 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2,20 +2,25 @@ #pragma once -#include "array.h" -#include "device.h" -#include "io/load.h" -#include "stream.h" +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/io/load.h" +#include "mlx/stream.h" + +#define DEFINE_VMAP() \ + virtual std::pair, std::vector> vmap( \ + const std::vector& inputs, const std::vector& axes) \ + override; #define DEFINE_GRADS() \ - array jvp( \ + std::vector jvp( \ const std::vector& primals, \ const std::vector& tangents, \ const std::vector& argnums) override; \ \ std::vector vjp( \ const std::vector& primals, \ - const array& cotan, \ + const std::vector& cotangents, \ const std::vector& argnums) override; #define DEFINE_PRINT(PRIMITIVE) \ @@ -47,18 +52,22 @@ class Primitive { /** * A primitive must know how to evaluate itself on - * the CPU/GPU for the given inputs and populate the output array. + * the CPU/GPU for the given inputs and populate the output arrays. * * To avoid unnecessary allocations, the evaluation function * is responsible for allocating space for the array. */ - virtual void eval_cpu(const std::vector& inputs, array& out) = 0; - virtual void eval_gpu(const std::vector& inputs, array& out) = 0; + virtual void eval_cpu( + const std::vector& inputs, + std::vector& outputs) = 0; + virtual void eval_gpu( + const std::vector& inputs, + std::vector& outputs) = 0; /** * The Jacobian-vector product. */ - virtual array jvp( + virtual std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums); @@ -68,16 +77,16 @@ class Primitive { */ virtual std::vector vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums); /** * The primitive must know how to vectorize itself across - * the given axes. The output is a pair containing the array - * representing the vectorized computation and the axis which - * corresponds to the output vectorized dimension. + * the given axes. 
The output is a pair containing the output arrays + * representing the vectorized computation and the axes which + * corresponds to the vectorized dimensions of each output. */ - virtual std::pair vmap( + virtual std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes); @@ -100,17 +109,42 @@ class Primitive { Stream stream_; }; -class Abs : public Primitive { +class UnaryPrimitive : public Primitive { + /** + * An abstract base class for a primitive with a single output. + */ public: - explicit Abs(Stream stream) : Primitive(stream){}; + explicit UnaryPrimitive(Stream stream) : Primitive(stream) {} + + virtual void eval_cpu(const std::vector& inputs, array& output) = 0; + virtual void eval_gpu(const std::vector& inputs, array& output) = 0; + + inline void eval_cpu( + const std::vector& inputs, + std::vector& outputs) override { + eval_cpu(inputs, outputs[0]); + } + inline void eval_gpu( + const std::vector& inputs, + std::vector& outputs) override { + eval_gpu(inputs, outputs[0]); + } + + virtual ~UnaryPrimitive() = default; + UnaryPrimitive(const UnaryPrimitive& other) = delete; + UnaryPrimitive(UnaryPrimitive&& other) = delete; + UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete; + UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; +}; + +class Abs : public UnaryPrimitive { + public: + explicit Abs(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Abs) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -119,17 +153,14 @@ class Abs : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Add : public Primitive { +class Add : public UnaryPrimitive { public: - explicit Add(Stream stream) : Primitive(stream){}; + explicit Add(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Add) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -138,10 +169,10 @@ class Add : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Arange : public Primitive { +class Arange : public UnaryPrimitive { public: explicit Arange(Stream stream, double start, double stop, double step) - : Primitive(stream), start_(start), stop_(stop), step_(step){}; + : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -157,17 +188,14 @@ class Arange : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArcCos : public Primitive { +class ArcCos : public UnaryPrimitive { public: - explicit ArcCos(Stream stream) : Primitive(stream){}; + explicit ArcCos(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(ArcCos) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -176,17 +204,14 @@ class ArcCos : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArcCosh : public 
Primitive { +class ArcCosh : public UnaryPrimitive { public: - explicit ArcCosh(Stream stream) : Primitive(stream){}; + explicit ArcCosh(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(ArcCosh) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -195,17 +220,14 @@ class ArcCosh : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArcSin : public Primitive { +class ArcSin : public UnaryPrimitive { public: - explicit ArcSin(Stream stream) : Primitive(stream){}; + explicit ArcSin(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(ArcSin) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -214,17 +236,14 @@ class ArcSin : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArcSinh : public Primitive { +class ArcSinh : public UnaryPrimitive { public: - explicit ArcSinh(Stream stream) : Primitive(stream){}; + explicit ArcSinh(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(ArcSinh) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -233,17 +252,14 @@ class ArcSinh : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArcTan : public Primitive { +class ArcTan : public UnaryPrimitive { public: - explicit ArcTan(Stream stream) : Primitive(stream){}; + explicit ArcTan(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(ArcTan) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -252,17 +268,14 @@ class ArcTan : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArcTanh : public Primitive { +class ArcTanh : public UnaryPrimitive { public: - explicit ArcTanh(Stream stream) : Primitive(stream){}; + explicit ArcTanh(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(ArcTanh) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -271,18 +284,15 @@ class ArcTanh : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArgPartition : public Primitive { +class ArgPartition : public UnaryPrimitive { public: explicit ArgPartition(Stream stream, int kth, int axis) - : Primitive(stream), kth_(kth), axis_(axis){}; + : UnaryPrimitive(stream), kth_(kth), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_PRINT(ArgPartition) bool is_equivalent(const 
Primitive& other) const override; @@ -293,7 +303,7 @@ class ArgPartition : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArgReduce : public Primitive { +class ArgReduce : public UnaryPrimitive { public: enum ReduceType { ArgMin, @@ -301,7 +311,7 @@ class ArgReduce : public Primitive { }; explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis) - : Primitive(stream), reduce_type_(reduce_type), axis_(axis){}; + : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -316,17 +326,15 @@ class ArgReduce : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ArgSort : public Primitive { +class ArgSort : public UnaryPrimitive { public: - explicit ArgSort(Stream stream, int axis) : Primitive(stream), axis_(axis){}; + explicit ArgSort(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_PRINT(ArgSort) bool is_equivalent(const Primitive& other) const override; @@ -336,18 +344,15 @@ class ArgSort : public Primitive { void eval(const std::vector& inputs, array& out); }; -class AsType : public Primitive { +class AsType : public UnaryPrimitive { public: explicit AsType(Stream stream, Dtype dtype) - : Primitive(stream), dtype_(dtype){}; + : UnaryPrimitive(stream), dtype_(dtype){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(AsType) bool is_equivalent(const Primitive& other) const override; @@ -358,14 +363,17 @@ class AsType : public Primitive { void eval(const std::vector& inputs, array& out); }; -class AsStrided : public Primitive { +class AsStrided : public UnaryPrimitive { public: explicit AsStrided( Stream stream, const std::vector& shape, const std::vector& strides, size_t offset) - : Primitive(stream), shape_(shape), strides_(strides), offset_(offset){}; + : UnaryPrimitive(stream), + shape_(shape), + strides_(strides), + offset_(offset){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -382,18 +390,15 @@ class AsStrided : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Broadcast : public Primitive { +class Broadcast : public UnaryPrimitive { public: explicit Broadcast(Stream stream, const std::vector& shape) - : Primitive(stream), shape_(shape){}; + : UnaryPrimitive(stream), shape_(shape){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Broadcast) bool is_equivalent(const Primitive& other) const override; @@ -404,17 +409,14 @@ class Broadcast : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Ceil : public Primitive { +class Ceil : public UnaryPrimitive { public: - explicit Ceil(Stream stream) : Primitive(stream){}; + explicit Ceil(Stream stream) : UnaryPrimitive(stream){}; void 
eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Ceil) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -423,18 +425,15 @@ class Ceil : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Concatenate : public Primitive { +class Concatenate : public UnaryPrimitive { public: explicit Concatenate(Stream stream, int axis) - : Primitive(stream), axis_(axis){}; + : UnaryPrimitive(stream), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Concatenate) bool is_equivalent(const Primitive& other) const override; @@ -445,7 +444,7 @@ class Concatenate : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Convolution : public Primitive { +class Convolution : public UnaryPrimitive { public: explicit Convolution( Stream stream, @@ -453,7 +452,7 @@ class Convolution : public Primitive { const std::vector& kernel_strides, const std::vector& kernel_dilation, const std::vector& input_dilation) - : Primitive(stream), + : UnaryPrimitive(stream), padding_(padding), kernel_strides_(kernel_strides), kernel_dilation_(kernel_dilation), @@ -464,7 +463,7 @@ class Convolution : public Primitive { std::vector vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) override; DEFINE_PRINT(Convolution) @@ -479,17 +478,14 @@ class Convolution : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Copy : public Primitive { +class Copy : public UnaryPrimitive { public: - explicit Copy(Stream stream) : Primitive(stream){}; + explicit Copy(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Copy) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -498,17 +494,14 @@ class Copy : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Cos : public Primitive { +class Cos : public UnaryPrimitive { public: - explicit Cos(Stream stream) : Primitive(stream){}; + explicit Cos(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Cos) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -517,17 +510,14 @@ class Cos : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Cosh : public Primitive { +class Cosh : public UnaryPrimitive { public: - explicit Cosh(Stream stream) : Primitive(stream){}; + explicit Cosh(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Cosh) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -536,17 +526,14 @@ class Cosh : public Primitive { void 
eval(const std::vector& inputs, array& out); }; -class Divide : public Primitive { +class Divide : public UnaryPrimitive { public: - explicit Divide(Stream stream) : Primitive(stream){}; + explicit Divide(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Divide) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -555,17 +542,32 @@ class Divide : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Remainder : public Primitive { +class DivMod : public Primitive { public: - explicit Remainder(Stream stream) : Primitive(stream){}; + explicit DivMod(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(DivMod) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, std::vector& outputs); +}; + +class Remainder : public UnaryPrimitive { + public: + explicit Remainder(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Remainder) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -574,18 +576,15 @@ class Remainder : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Equal : public Primitive { +class Equal : public UnaryPrimitive { public: explicit Equal(Stream stream, bool equal_nan = false) - : Primitive(stream), equal_nan_(equal_nan){}; + : UnaryPrimitive(stream), equal_nan_(equal_nan){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Equal) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -595,17 +594,14 @@ class Equal : public Primitive { bool equal_nan_; }; -class Erf : public Primitive { +class Erf : public UnaryPrimitive { public: - explicit Erf(Stream stream) : Primitive(stream){}; + explicit Erf(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Erf) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -614,17 +610,14 @@ class Erf : public Primitive { void eval(const std::vector& inputs, array& out); }; -class ErfInv : public Primitive { +class ErfInv : public UnaryPrimitive { public: - explicit ErfInv(Stream stream) : Primitive(stream){}; + explicit ErfInv(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(ErfInv) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -633,17 +626,14 @@ class ErfInv : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Exp : public Primitive { +class Exp : public 
UnaryPrimitive { public: - explicit Exp(Stream stream) : Primitive(stream){}; + explicit Exp(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Exp) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -652,22 +642,19 @@ class Exp : public Primitive { void eval(const std::vector& inputs, array& out); }; -class FFT : public Primitive { +class FFT : public UnaryPrimitive { public: explicit FFT( Stream stream, const std::vector& axes, bool inverse, bool real) - : Primitive(stream), axes_(axes), inverse_(inverse), real_(real){}; + : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(FFT) @@ -681,17 +668,14 @@ class FFT : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Floor : public Primitive { +class Floor : public UnaryPrimitive { public: - explicit Floor(Stream stream) : Primitive(stream){}; + explicit Floor(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Floor) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -700,17 +684,14 @@ class Floor : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Full : public Primitive { +class Full : public UnaryPrimitive { public: - explicit Full(Stream stream) : Primitive(stream){}; + explicit Full(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Full) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -719,21 +700,18 @@ class Full : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Gather : public Primitive { +class Gather : public UnaryPrimitive { public: explicit Gather( Stream stream, const std::vector& axes, const std::vector& slice_sizes) - : Primitive(stream), axes_(axes), slice_sizes_(slice_sizes){}; + : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Gather) bool is_equivalent(const Primitive& other) const override; @@ -744,17 +722,14 @@ class Gather : public Primitive { std::vector slice_sizes_; }; -class Greater : public Primitive { +class Greater : public UnaryPrimitive { public: - explicit Greater(Stream stream) : Primitive(stream){}; + explicit Greater(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() 
DEFINE_PRINT(Greater) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -763,17 +738,14 @@ class Greater : public Primitive { void eval(const std::vector& inputs, array& out); }; -class GreaterEqual : public Primitive { +class GreaterEqual : public UnaryPrimitive { public: - explicit GreaterEqual(Stream stream) : Primitive(stream){}; + explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(GreaterEqual) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -782,17 +754,14 @@ class GreaterEqual : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Less : public Primitive { +class Less : public UnaryPrimitive { public: - explicit Less(Stream stream) : Primitive(stream){}; + explicit Less(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Less) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -801,17 +770,14 @@ class Less : public Primitive { void eval(const std::vector& inputs, array& out); }; -class LessEqual : public Primitive { +class LessEqual : public UnaryPrimitive { public: - explicit LessEqual(Stream stream) : Primitive(stream){}; + explicit LessEqual(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(LessEqual) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -820,14 +786,14 @@ class LessEqual : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Load : public Primitive { +class Load : public UnaryPrimitive { public: explicit Load( Stream stream, std::shared_ptr reader, size_t offset, bool swap_endianness = false) - : Primitive(stream), + : UnaryPrimitive(stream), reader_(reader), offset_(offset), swap_endianness_(swap_endianness){}; @@ -844,19 +810,17 @@ class Load : public Primitive { bool swap_endianness_; }; -class Log : public Primitive { +class Log : public UnaryPrimitive { public: enum Base { two, ten, e }; - explicit Log(Stream stream, Base base) : Primitive(stream), base_(base){}; + explicit Log(Stream stream, Base base) + : UnaryPrimitive(stream), base_(base){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Log) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -866,17 +830,14 @@ class Log : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Log1p : public Primitive { +class Log1p : public UnaryPrimitive { public: - explicit Log1p(Stream stream) : Primitive(stream){}; + explicit Log1p(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Log1p) @@ -884,17 +845,14 @@ class Log1p : 
public Primitive { void eval(const std::vector& inputs, array& out); }; -class LogicalNot : public Primitive { +class LogicalNot : public UnaryPrimitive { public: - explicit LogicalNot(Stream stream) : Primitive(stream){}; + explicit LogicalNot(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(LogicalNot) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -903,17 +861,14 @@ class LogicalNot : public Primitive { void eval(const std::vector& inputs, array& out); }; -class LogicalAnd : public Primitive { +class LogicalAnd : public UnaryPrimitive { public: - explicit LogicalAnd(Stream stream) : Primitive(stream){}; + explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(LogicalAnd) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -922,17 +877,14 @@ class LogicalAnd : public Primitive { void eval(const std::vector& inputs, array& out); }; -class LogicalOr : public Primitive { +class LogicalOr : public UnaryPrimitive { public: - explicit LogicalOr(Stream stream) : Primitive(stream){}; + explicit LogicalOr(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(LogicalOr) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -941,17 +893,14 @@ class LogicalOr : public Primitive { void eval(const std::vector& inputs, array& out); }; -class LogAddExp : public Primitive { +class LogAddExp : public UnaryPrimitive { public: - explicit LogAddExp(Stream stream) : Primitive(stream){}; + explicit LogAddExp(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(LogAddExp) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -960,37 +909,30 @@ class LogAddExp : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Matmul : public Primitive { +class Matmul : public UnaryPrimitive { public: - explicit Matmul(Stream stream) : Primitive(stream){}; + explicit Matmul(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::vector vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - DEFINE_PRINT(Matmul) DEFINE_DEFAULT_IS_EQUIVALENT() }; -class Maximum : public Primitive { +class Maximum : public UnaryPrimitive { public: - explicit Maximum(Stream stream) : Primitive(stream){}; + explicit Maximum(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, 
- const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Maximum) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -999,17 +941,14 @@ class Maximum : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Minimum : public Primitive { +class Minimum : public UnaryPrimitive { public: - explicit Minimum(Stream stream) : Primitive(stream){}; + explicit Minimum(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Minimum) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1018,17 +957,14 @@ class Minimum : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Multiply : public Primitive { +class Multiply : public UnaryPrimitive { public: - explicit Multiply(Stream stream) : Primitive(stream){}; + explicit Multiply(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Multiply) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1037,17 +973,14 @@ class Multiply : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Negative : public Primitive { +class Negative : public UnaryPrimitive { public: - explicit Negative(Stream stream) : Primitive(stream){}; + explicit Negative(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Negative) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1056,17 +989,14 @@ class Negative : public Primitive { void eval(const std::vector& inputs, array& out); }; -class NotEqual : public Primitive { +class NotEqual : public UnaryPrimitive { public: - explicit NotEqual(Stream stream) : Primitive(stream){}; + explicit NotEqual(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(NotEqual) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1075,14 +1005,14 @@ class NotEqual : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Pad : public Primitive { +class Pad : public UnaryPrimitive { public: explicit Pad( Stream stream, const std::vector& axes, const std::vector& low_pad_size, const std::vector& high_pad_size) - : Primitive(stream), + : UnaryPrimitive(stream), axes_(axes), low_pad_size_(low_pad_size), high_pad_size_(high_pad_size){}; @@ -1090,10 +1020,7 @@ class Pad : public Primitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Pad) bool is_equivalent(const Primitive& other) const override; @@ -1106,18 +1033,15 @@ class Pad : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Partition : public Primitive { +class 
Partition : public UnaryPrimitive { public: explicit Partition(Stream stream, int kth, int axis) - : Primitive(stream), kth_(kth), axis_(axis){}; + : UnaryPrimitive(stream), kth_(kth), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Partition) bool is_equivalent(const Primitive& other) const override; @@ -1129,17 +1053,14 @@ class Partition : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Power : public Primitive { +class Power : public UnaryPrimitive { public: - explicit Power(Stream stream) : Primitive(stream){}; + explicit Power(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Power) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1148,14 +1069,14 @@ class Power : public Primitive { void eval(const std::vector& inputs, array& out); }; -class QuantizedMatmul : public Primitive { +class QuantizedMatmul : public UnaryPrimitive { public: explicit QuantizedMatmul( Stream stream, int group_size, int bits, bool transpose) - : Primitive(stream), + : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), transpose_(transpose){}; @@ -1163,10 +1084,7 @@ class QuantizedMatmul : public Primitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(QuantizedMatmul) bool is_equivalent(const Primitive& other) const override; @@ -1179,18 +1097,15 @@ class QuantizedMatmul : public Primitive { void eval(const std::vector& inputs, array& out); }; -class RandomBits : public Primitive { +class RandomBits : public UnaryPrimitive { public: explicit RandomBits(Stream stream, const std::vector& shape, int width) - : Primitive(stream), shape_(shape), width_(width){}; + : UnaryPrimitive(stream), shape_(shape), width_(width){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_PRINT(RandomBits) bool is_equivalent(const Primitive& other) const override; @@ -1201,18 +1116,15 @@ class RandomBits : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Reshape : public Primitive { +class Reshape : public UnaryPrimitive { public: explicit Reshape(Stream stream, const std::vector& shape) - : Primitive(stream), shape_(shape){}; + : UnaryPrimitive(stream), shape_(shape){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Reshape) bool is_equivalent(const Primitive& other) const override; @@ -1223,7 +1135,7 @@ class Reshape : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Reduce : public Primitive { +class Reduce : public UnaryPrimitive { public: enum ReduceType { And, Or, Sum, Prod, Min, Max 
}; @@ -1231,17 +1143,16 @@ class Reduce : public Primitive { Stream stream, ReduceType reduce_type, const std::vector& axes) - : Primitive(stream), reduce_type_(reduce_type), axes_(axes){}; + : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; + DEFINE_VMAP() + std::vector vjp( const std::vector& primals, - const array& cotan, + const std::vector& cotangents, const std::vector& argnums) override; void print(std::ostream& os) override { @@ -1275,17 +1186,14 @@ class Reduce : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Round : public Primitive { +class Round : public UnaryPrimitive { public: - explicit Round(Stream stream) : Primitive(stream){}; + explicit Round(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Round) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1294,7 +1202,7 @@ class Round : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Scan : public Primitive { +class Scan : public UnaryPrimitive { public: enum ReduceType { Max, Min, Sum, Prod }; @@ -1304,7 +1212,7 @@ class Scan : public Primitive { int axis, bool reverse, bool inclusive) - : Primitive(stream), + : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis), reverse_(reverse), @@ -1313,11 +1221,9 @@ class Scan : public Primitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS(); + void print(std::ostream& os) override { os << "Cum"; switch (reduce_type_) { @@ -1347,7 +1253,7 @@ class Scan : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Scatter : public Primitive { +class Scatter : public UnaryPrimitive { public: enum ReduceType { Max, Min, Sum, Prod, None }; @@ -1355,7 +1261,7 @@ class Scatter : public Primitive { Stream stream, ReduceType reduce_type, const std::vector& axes) - : Primitive(stream), reduce_type_(reduce_type), axes_(axes){}; + : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1369,17 +1275,14 @@ class Scatter : public Primitive { std::vector axes_; }; -class Sigmoid : public Primitive { +class Sigmoid : public UnaryPrimitive { public: - explicit Sigmoid(Stream stream) : Primitive(stream){}; + explicit Sigmoid(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sigmoid) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1388,17 +1291,14 @@ class Sigmoid : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Sign : public Primitive { +class Sign : public UnaryPrimitive { public: - explicit Sign(Stream stream) : Primitive(stream){}; + explicit Sign(Stream stream) 
: UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sign) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1407,17 +1307,14 @@ class Sign : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Sin : public Primitive { +class Sin : public UnaryPrimitive { public: - explicit Sin(Stream stream) : Primitive(stream){}; + explicit Sin(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sin) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1426,17 +1323,14 @@ class Sin : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Sinh : public Primitive { +class Sinh : public UnaryPrimitive { public: - explicit Sinh(Stream stream) : Primitive(stream){}; + explicit Sinh(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sinh) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1445,14 +1339,14 @@ class Sinh : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Slice : public Primitive { +class Slice : public UnaryPrimitive { public: explicit Slice( Stream stream, const std::vector& start_indices, const std::vector& end_indices, const std::vector& strides) - : Primitive(stream), + : UnaryPrimitive(stream), start_indices_(start_indices), end_indices_(end_indices), strides_(strides){}; @@ -1460,10 +1354,7 @@ class Slice : public Primitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Slice) bool is_equivalent(const Primitive& other) const override; @@ -1476,17 +1367,14 @@ class Slice : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Softmax : public Primitive { +class Softmax : public UnaryPrimitive { public: - explicit Softmax(Stream stream) : Primitive(stream){}; + explicit Softmax(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Softmax) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1495,17 +1383,15 @@ class Softmax : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Sort : public Primitive { +class Sort : public UnaryPrimitive { public: - explicit Sort(Stream stream, int axis) : Primitive(stream), axis_(axis){}; + explicit Sort(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sort) bool 
is_equivalent(const Primitive& other) const override; @@ -1516,17 +1402,14 @@ class Sort : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Square : public Primitive { +class Square : public UnaryPrimitive { public: - explicit Square(Stream stream) : Primitive(stream){}; + explicit Square(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Square) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1535,18 +1418,15 @@ class Square : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Sqrt : public Primitive { +class Sqrt : public UnaryPrimitive { public: explicit Sqrt(Stream stream, bool recip = false) - : Primitive(stream), recip_(recip){}; + : UnaryPrimitive(stream), recip_(recip){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Sqrt) bool is_equivalent(const Primitive& other) const override; @@ -1556,17 +1436,14 @@ class Sqrt : public Primitive { bool recip_; }; -class StopGradient : public Primitive { +class StopGradient : public UnaryPrimitive { public: - explicit StopGradient(Stream stream) : Primitive(stream){}; + explicit StopGradient(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_PRINT(StopGradient) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1574,17 +1451,14 @@ class StopGradient : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Subtract : public Primitive { +class Subtract : public UnaryPrimitive { public: - explicit Subtract(Stream stream) : Primitive(stream){}; + explicit Subtract(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Subtract) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1593,17 +1467,14 @@ class Subtract : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Tan : public Primitive { +class Tan : public UnaryPrimitive { public: - explicit Tan(Stream stream) : Primitive(stream){}; + explicit Tan(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Tan) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1612,17 +1483,14 @@ class Tan : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Tanh : public Primitive { +class Tanh : public UnaryPrimitive { public: - explicit Tanh(Stream stream) : Primitive(stream){}; + explicit Tanh(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const 
std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Tanh) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1631,17 +1499,14 @@ class Tanh : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Uniform : public Primitive { +class Uniform : public UnaryPrimitive { public: - explicit Uniform(Stream stream) : Primitive(stream){}; + explicit Uniform(Stream stream) : UnaryPrimitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_PRINT(Uniform) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1649,18 +1514,15 @@ class Uniform : public Primitive { void eval(const std::vector& inputs, array& out); }; -class Transpose : public Primitive { +class Transpose : public UnaryPrimitive { public: explicit Transpose(Stream stream, const std::vector& axes) - : Primitive(stream), axes_(axes){}; + : UnaryPrimitive(stream), axes_(axes){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - std::pair vmap( - const std::vector& inputs, - const std::vector& axes) override; - + DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Transpose) bool is_equivalent(const Primitive& other) const override; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 67e4731ad..a46171443 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include #include #include @@ -26,6 +25,16 @@ namespace mlx::core { int detail::InTracing::tracing_counter{0}; void simplify(const std::vector& outputs) { + // Some notes about how this function works + // + // Step 1: Traverse the graph and build a tape. During the graph + // traversal we: + // - Build a map of inputs to their parents. + // - Record scalar inputs in a map in order to fuse them. + // Step 2: Process the tape. A node in the tape has inputs and outputs. + // - Scalar inputs are replaced with their canonical scalar. + // - We check each input's output nodes. Every output node that matches + // the current node gets fused into the current node. 
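To make the fusion notes above concrete, here is a minimal usage sketch (mlx::core namespace assumed; it mirrors the "test simplify scalars" case added to graph_optimize_tests.cpp later in this patch):

  // Sketch, not part of the patch: two maximum nodes each consume
  // their own scalar 0.0f before simplification.
  auto a = array({-1.0f, 2.0f});
  auto b = maximum(a, array(0.0f));
  auto c = maximum(-a, array(0.0f));
  auto d = b + c;
  simplify({d});
  // Both nodes now share one canonical scalar node:
  // b.inputs()[1].id() == c.inputs()[1].id()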
std::function recurse; std::queue tape; std::unordered_set cache; @@ -54,7 +63,7 @@ void simplify(const std::vector& outputs) { return std::make_pair(v, a.dtype().val); }; - // DFS the graph to log the parents + // DFS the graph to build the tape, and log parents and scalars recurse = [&](const array& a) { auto id = a.id(); if (cache.find(id) != cache.end()) { return; } for (int i = 0; i < a.inputs().size(); i++) { auto& in = a.inputs()[i]; parents_map[in.id()].push_back({a, i}); + for (auto& s : a.siblings()) { + parents_map[in.id()].push_back({s, i}); + } recurse(in); } cache.insert(id); + for (auto& s : a.siblings()) { + cache.insert(s.id()); + } + tape.push(a); if (is_scalar(a)) { scalars.insert({get_scalar_rep(a), a}); @@ -78,26 +94,33 @@ void simplify(const std::vector& outputs) { // Helper that fuses two arrays in the graph by setting the parents of the // source to point to the destination auto fuse = [&](array& dst, array& src) { - auto src_parents = parents_map.find(src.id()); - if (src_parents == parents_map.end()) { - return; - } - - auto& pairs = parents_map[dst.id()]; - for (auto& parent : src_parents->second) { - parent.first.editable_inputs()[parent.second] = dst; - pairs.push_back(parent); + // Canonicalize the order of the primitive's outputs + auto sources = src.outputs(); + auto dests = dst.outputs(); + // For each src parent, point it to the corresponding dest + for (int i = 0; i < sources.size(); ++i) { + auto src_parents = parents_map.find(sources[i].id()); + if (src_parents == parents_map.end()) { + continue; + } + auto& pairs = parents_map[dests[i].id()]; + for (auto& parent : src_parents->second) { + parent.first.inputs()[parent.second] = dests[i]; + pairs.push_back(parent); + } + // Remove the source from the map to avoid fusing with it again + parents_map.erase(src_parents); } }; - // Walk the graph - cache.clear(); - // Depth-1 array equivalence check. 
auto array_equivalent = [](const array& a, const array& b) { if (!a.has_primitive() || !b.has_primitive()) { return false; } + if (a.primitive_id() == b.primitive_id()) { + return false; + } const auto& pa = a.primitive(); const auto& pb = b.primitive(); if (typeid(pa) != typeid(pb)) { @@ -117,14 +140,11 @@ void simplify(const std::vector& outputs) { return pa.is_equivalent(pb); }; + // Walk the graph while (!tape.empty()) { auto arr = std::move(tape.front()); tape.pop(); - if (cache.find(arr.id()) != cache.end()) { - continue; - } - // Check if we can fuse scalars if (is_scalar(arr)) { auto scalar = scalars.find(get_scalar_rep(arr)); @@ -134,28 +154,35 @@ void simplify(const std::vector& outputs) { } } - // Check if we can fuse the parents of this array - auto parents = parents_map.find(arr.id()); - if (parents != parents_map.end()) { - std::vector mask(parents->second.size(), false); - auto N = parents->second.size(); - for (int i = 0; i < N; i++) { - if (mask[i]) { - continue; - } - for (int j = i + 1; j < N; j++) { - if (mask[j]) { + // Helper to check if we can fuse the parents of the + // given array + auto maybe_fuse_parents = [&](auto& a) { + auto parents = parents_map.find(a.id()); + if (parents != parents_map.end()) { + auto N = parents->second.size(); + std::vector mask(N, false); + for (int i = 0; i < N; i++) { + if (mask[i]) { continue; } - auto& src = parents->second[j].first; - auto& dst = parents->second[i].first; - if (src.id() != dst.id() && array_equivalent(src, dst)) { - cache.insert(src.id()); - fuse(dst, src); - mask[j] = true; + for (int j = i + 1; j < N; j++) { + if (mask[j]) { + continue; + } + auto& src = parents->second[j].first; + auto& dst = parents->second[i].first; + if (src.id() != dst.id() && array_equivalent(src, dst)) { + fuse(dst, src); + mask[j] = true; + } } } } + }; + + maybe_fuse_parents(arr); + for (auto& s : arr.siblings()) { + maybe_fuse_parents(s); } } } @@ -177,11 +204,14 @@ void eval(const std::vector& outputs) { // stream, we need to manage the dependency. if (!in.is_evaled()) { if (a.primitive().stream() != in.primitive().stream()) { - deps.insert({in.id(), std::shared_future{}}); + deps.insert({in.primitive_id(), std::shared_future{}}); } } } cache.insert(id); + for (auto& s : a.siblings()) { + cache.insert(s.id()); + } if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) { if (!a.has_primitive()) { throw std::invalid_argument( @@ -191,17 +221,23 @@ void eval(const std::vector& outputs) { } }; + // We have to store the output primitive ids because the arrays are + // detached during eval and we need to use them for synchronization + // at the end of this function + std::vector output_primitive_ids; for (auto& arr : outputs) { if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) { recurse(arr); // Insert a dependency for every output to synchronize // with at the end. 
- if (!arr.is_evaled()) { - deps.insert({arr.id(), std::shared_future{}}); + if (!arr.is_evaled() && deps.find(arr.primitive_id()) == deps.end()) { + deps.insert({arr.primitive_id(), std::shared_future{}}); + output_primitive_ids.push_back(arr.primitive_id()); } } } + std::vector>> ps; while (!tape.empty()) { auto arr = std::move(tape.front()); tape.pop(); @@ -215,13 +251,14 @@ void eval(const std::vector& outputs) { auto stream = arr.primitive().stream(); std::vector> arr_deps; for (auto& in : arr.inputs()) { - if (auto it = deps.find(in.id()); it != deps.end()) { + if (auto it = deps.find(in.primitive_id()); it != deps.end()) { arr_deps.push_back(it->second); } } std::shared_ptr> p{nullptr}; - if (auto it = deps.find(arr.id()); it != deps.end()) { + if (auto it = deps.find(arr.primitive_id()); it != deps.end()) { p = std::make_unique>(); + ps.push_back(p); it->second = p->get_future().share(); } @@ -234,15 +271,19 @@ } else { auto task = [arr, stream, - arr_deps = std::move(arr_deps), + deps = std::move(arr_deps), p = std::move(p)]() mutable { - for (auto& d : arr_deps) { + for (auto& d : deps) { d.wait(); } scheduler::notify_new_task(stream); - arr.primitive().eval_cpu(arr.inputs(), arr); + auto outputs = arr.outputs(); + arr.primitive().eval_cpu(arr.inputs(), outputs); if (!arr.is_tracer()) { arr.detach(); + for (auto s : arr.siblings()) { + s.detach(); + } } if (p) { p->set_value(); @@ -252,10 +293,8 @@ scheduler::enqueue(stream, std::move(task)); } } - for (auto& arr : outputs) { - if (auto it = deps.find(arr.id()); it != deps.end()) { - it->second.wait(); - } + for (auto id : output_primitive_ids) { + deps[id].wait(); } } @@ -301,8 +340,8 @@ std::pair, std::vector> vjp( output_cotan_pairs.emplace_back(i, cotan_index++); } - // Topologically sort the compute graph, record outputs - // in the tape if a gradient is needed. + // Topologically sort the compute graph, adding the nodes + // which need a gradient to the tape. 
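The vjp rewrite that follows is what lets multi-output primitives take part in reverse-mode autodiff: the tape walk gathers one cotangent per output of a primitive and fills in zeros_like for any output the caller did not use. A hedged usage sketch (assuming the C++ vjp overload declared in this file that takes explicit cotangents):

  // Sketch, not part of the patch: divmod produces two outputs from
  // a single DivMod primitive.
  auto fun = [](const std::vector<array>& in) {
    return divmod(in[0], in[1]);
  };
  auto a = array({7.0f, 8.0f});
  auto b = array({2.0f, 3.0f});
  // One cotangent per output; both reach DivMod::vjp in a single call.
  std::vector<array> cotans = {ones({2}), ones({2})};
  auto [outs, grads] = vjp(fun, {a, b}, cotans);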
std::unordered_set cache; std::unordered_set calc_grad; for (auto& primal : primals_) { @@ -315,34 +354,41 @@ std::pair, std::vector> vjp( std::function recurse; recurse = [&](auto& a) { - auto id = a.id(); - a.set_tracer(false); - // Check if visited and add to cache if not - if (auto inserted = cache.insert(id); !inserted.second) { + if (auto inserted = cache.insert(a.id()); !inserted.second) { return; } + a.set_tracer(false); + for (auto s : a.siblings()) { + s.set_tracer(false); + cache.insert(s.id()); + } - for (auto& input : a.editable_inputs()) { + for (auto& input : a.inputs()) { recurse(input); } // Stop grad - if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) { - return; + if (a.has_primitive()) { + if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) { + return; + } } // Calculate gradient if any inputs require gradient for (auto& input : a.inputs()) { if (calc_grad.find(input.id()) != calc_grad.end()) { tape.push_back(a); - calc_grad.insert(id); + calc_grad.insert(a.id()); + for (auto& s : a.siblings()) { + calc_grad.insert(s.id()); + } break; } } }; - for (auto& out : outputs) { + for (auto out : outputs) { recurse(out); } @@ -363,14 +409,28 @@ std::pair, std::vector> vjp( } } - auto cotan_it = cotan_map.find(a.id()); - if (cotan_it == cotan_map.end()) { + // Check if the array or any of its siblings has a cotangent; + // if not, we can skip this primitive + auto outputs = a.outputs(); + bool has_cotans = + std::any_of(outputs.cbegin(), outputs.cend(), [&cotan_map](auto& s) { + return cotan_map.find(s.id()) != cotan_map.end(); + }); + if (!has_cotans) { continue; } - auto cotan = cotan_map.extract(cotan_it).mapped(); - auto vjps = a.primitive().vjp(a.inputs(), cotan, argnums); auto s = a.primitive().stream(); + std::vector cotangents{}; + for (auto& o : outputs) { + if (auto cotan_it = cotan_map.find(o.id()); cotan_it != cotan_map.end()) { + cotangents.push_back(cotan_map.extract(cotan_it).mapped()); + } else { + cotangents.push_back(zeros_like(o, s)); + } + } + + auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums); // Accumulate the vector-jacobian products for each input for (int i = 0; i < argnums.size(); ++i) { auto in_id = a.inputs()[argnums[i]].id(); @@ -411,6 +471,9 @@ std::pair, std::vector> jvp( const std::function(const std::vector&)>& fun, const std::vector& primals, const std::vector& tangents) { + // Set the global tracing flag. + detail::InTracing in_tracing; + if (primals.size() != tangents.size()) { throw std::invalid_argument( "[jvp] Number of inputs does not match number of tangents."); } for (int i = 0; i < primals.size(); ++i) { if (primals[i].shape() != tangents[i].shape()) { } } - // Set the global tracing flag. - detail::InTracing in_tracing; - std::vector primals_; for (auto& p : primals) { auto s = p.has_primitive() ? 
p.primitive().stream() @@ -448,36 +508,44 @@ std::pair, std::vector> jvp( std::function recurse; recurse = [&](auto& a) { - auto id = a.id(); - a.set_tracer(false); - // Check if visited and add to cache if not - if (auto inserted = cache.insert(id); !inserted.second) { + if (auto inserted = cache.insert(a.id()); !inserted.second) { return; } + a.set_tracer(false); + for (auto s : a.siblings()) { + s.set_tracer(false); + cache.insert(s.id()); + } - for (auto& input : a.editable_inputs()) { + for (auto input : a.inputs()) { recurse(input); } // Stop grad - if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) { - return; + if (a.has_primitive()) { + if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) { + return; + } } // Calculate gradient if any inputs require gradient for (auto& input : a.inputs()) { if (calc_grad.find(input.id()) != calc_grad.end()) { tape.push_back(a); - calc_grad.insert(id); + calc_grad.insert(a.id()); + for (auto& s : a.siblings()) { + calc_grad.insert(s.id()); + } break; } } }; - for (auto& out : outputs) { + for (auto out : outputs) { recurse(out); } + std::unordered_map tan_map; for (int i = 0; i < primals_.size(); ++i) { tan_map.insert({primals_[i].id(), tangents[i]}); @@ -494,8 +562,11 @@ std::pair, std::vector> jvp( } } - auto jvp = a.primitive().jvp(a.inputs(), tangents, argnums); - tan_map.insert({a.id(), jvp}); + auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums); + auto outputs = a.outputs(); + for (int i = 0; i < jvps.size(); ++i) { + tan_map.insert({outputs[i].id(), jvps[i]}); + } } std::vector jvps; @@ -578,8 +649,8 @@ std::pair, std::vector> vmap_trace( const std::function(const std::vector&)>& fun, const std::vector& inputs, const std::vector& in_axes) { - // Set the global tracing flag - InTracing in_tracing; + // Set the global tracing flag. 
+ detail::InTracing in_tracing; if (in_axes.size() != inputs.size()) { throw std::invalid_argument( @@ -627,48 +698,56 @@ std::vector vmap_replace( std::unordered_map> tmap; std::unordered_set needs_vmap; - for (int i = 0; i < s_inputs.size(); ++i) { - if (in_axes[i] != -1) { - tmap.insert({s_inputs[i].id(), {inputs[i], in_axes[i]}}); - needs_vmap.insert(s_inputs[i].id()); - } - } - - // Topologically sort the graph std::unordered_set cache; for (int i = 0; i < s_inputs.size(); ++i) { auto in = s_inputs[i]; if (in_axes[i] != -1) { + tmap.insert({in.id(), {inputs[i], in_axes[i]}}); + needs_vmap.insert(in.id()); in.set_tracer(false); } cache.insert(in.id()); } + + // Topologically sort the graph std::vector tape; std::function recurse; recurse = [&](const array& a) { - // Stop at inputs to the vmap function auto id = a.id(); if (cache.find(id) != cache.end()) { return; } + cache.insert(id); + for (auto& s : a.siblings()) { + cache.insert(s.id()); + } + + // Recurse on inputs for (auto& input : a.inputs()) { recurse(input); } - cache.insert(id); + // If any input needs a vmap, then the outputs also need + // a vmap for (auto& input : a.inputs()) { if (needs_vmap.find(input.id()) != needs_vmap.end()) { - needs_vmap.insert(id); tape.push_back(a); tape.back().set_tracer(false); + needs_vmap.insert(a.id()); + for (auto s : a.siblings()) { + needs_vmap.insert(s.id()); + s.set_tracer(false); + } break; } } }; for (auto& out : s_outputs) { - recurse(out); + if (out.has_primitive()) { + recurse(out); + } } // Transform each primitive in the graph with @@ -686,16 +765,19 @@ std::vector vmap_replace( v_axes.push_back(-1); } } - auto out_and_axis = a.primitive().vmap(v_inputs, v_axes); - tmap.insert({a.id(), out_and_axis}); + auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes); + // For each primitive's outputs add its id, the vout id and the vax + auto outputs = a.outputs(); + for (int i = 0; i < v_outputs.size(); ++i) { + tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}}); + } } // Populate the outputs and make sure all the output axes are // in the right place std::vector outputs; for (int i = 0; i < s_outputs.size(); ++i) { - auto map_it = tmap.find(s_outputs[i].id()); - if (map_it != tmap.end()) { + if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) { auto& [out, vdim] = map_it->second; if (vdim != out_axes[i]) { if (out_axes[i] >= out.ndim()) { @@ -704,11 +786,7 @@ std::vector vmap_replace( << out.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - std::vector reorder(out.ndim()); - std::iota(reorder.begin(), reorder.end(), 0); - reorder.erase(reorder.begin() + vdim); - reorder.insert(reorder.begin() + out_axes[i], vdim); - out = transpose(out, reorder); + out = moveaxis(out, vdim, out_axes[i]); } outputs.push_back(out); } else { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 1b865fbd1..540c6e99a 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -303,6 +303,33 @@ void init_ops(py::module_& m) { Returns: array: The quotient ``a / b``. )pbdoc"); + m.def( + "divmod", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return divmod(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + divmod(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array + + Element-wise quotient and remainder. 
+
+        The function ``divmod(a, b)`` is equivalent to but faster than
+        ``(a // b, a % b)``. The function uses numpy-style broadcasting
+        semantics. Either or both input arrays can also be scalars.
+
+        Args:
+            a (array): Input array or scalar.
+            b (array): Input array or scalar.
+
+        Returns:
+            tuple(array, array): The quotient ``a // b`` and remainder ``a % b``.
+      )pbdoc");
   m.def(
       "floor_divide",
       [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
diff --git a/python/tests/test_graph.py b/python/tests/test_graph.py
new file mode 100644
index 000000000..4b8f6d86a
--- /dev/null
+++ b/python/tests/test_graph.py
@@ -0,0 +1,37 @@
+# Copyright © 2023 Apple Inc.
+
+import io
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+
+
+class TestGraph(mlx_tests.MLXTestCase):
+    def test_to_dot(self):
+        # Simply test that a few cases run.
+        # Nothing too specific about the graph format
+        # for now, to keep it flexible.
+        a = mx.array(1.0)
+        f = io.StringIO()
+        mx.export_to_dot(f, a)
+        f.seek(0)
+        self.assertTrue(len(f.read()) > 0)
+
+        b = mx.array(2.0)
+        c = a + b
+        f = io.StringIO()
+        mx.export_to_dot(f, c)
+        f.seek(0)
+        self.assertTrue(len(f.read()) > 0)
+
+        # Multi output case
+        c = mx.divmod(a, b)
+        f = io.StringIO()
+        mx.export_to_dot(f, *c)
+        f.seek(0)
+        self.assertTrue(len(f.read()) > 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 13f814fc1..433188b9a 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1314,7 +1314,7 @@ class TestOps(mlx_tests.MLXTestCase):
             for axis in (None, 0, 1, 2):
                 c_npy = npop(a_npy, axis=axis)
                 c_mlx = mxop(a_mlx, axis=axis)
-                self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-4, atol=1e-4))
+                self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))

         for op in ["cumsum", "cumprod", "cummax", "cummin"]:
             c1 = mxop(a_mlx, axis=2)
@@ -1597,6 +1597,32 @@ class TestOps(mlx_tests.MLXTestCase):
             np.outer,
         )

+    def test_divmod(self):
+        # A few sizes for the inputs with and without broadcasting
+        sizes = [
+            ((1,), (1,)),
+            ((1,), (10,)),
+            ((10,), (1,)),
+            ((3,), (3,)),
+            ((2, 2, 2), (1, 2, 1)),
+            ((2, 1, 2), (1, 2, 1)),
+            ((2, 2, 2, 2), (2, 2, 2, 2)),
+        ]
+        types = [np.uint16, np.uint32, np.int32, np.float16, np.float32]
+        for s1, s2 in sizes:
+            for t in types:
+                a_np = np.random.uniform(1, 100, size=s1).astype(t)
+                b_np = np.random.uniform(1, 100, size=s2).astype(t)
+                np_out = np.divmod(a_np, b_np)
+                mx_out = mx.divmod(mx.array(a_np), mx.array(b_np))
+                self.assertTrue(
+                    np.allclose(np_out[0], mx_out[0]), msg=f"Shapes {s1} {s2}, Type {t}"
+                )
+                # Also verify the remainder, not just the quotient
+                self.assertTrue(
+                    np.allclose(np_out[1], mx_out[1]), msg=f"Shapes {s1} {s2}, Type {t}"
+                )
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/graph_optimize_tests.cpp b/tests/graph_optimize_tests.cpp
index 8c2da162c..d30005fc6 100644
--- a/tests/graph_optimize_tests.cpp
+++ b/tests/graph_optimize_tests.cpp
@@ -7,19 +7,29 @@
 using namespace mlx::core;

 TEST_CASE("test simplify scalars") {
-  auto a = array({-1.0f, 2.0f});
-  auto b = maximum(a, array(0.0f));
-  auto c = maximum(-a, array(0.0f));
-  auto d = b + c;
-  simplify({d});
-  CHECK(b.inputs()[1].id() == c.inputs()[1].id());
+  {
+    auto a = array(-1.0f);
+    auto b = array(-1.0f);
+    auto c = abs(a);
+    auto d = abs(b);
+    simplify({c, d});
+    CHECK(c.inputs()[0].id() == d.inputs()[0].id());
+  }
+
+  {
+    auto a = array({-1.0f, 2.0f});
+    auto b = maximum(a, array(0.0f));
+    auto c = maximum(-a, array(0.0f));
+    auto d = b + c;
+    simplify({d});
+    CHECK(b.inputs()[1].id() == c.inputs()[1].id());
+  }
 }

 TEST_CASE("test simplify") {
   auto a = array({1.0f, 2.0f});
   auto b = exp(a) + exp(a);
   simplify(b);
-  eval(b);
   CHECK(b.inputs()[0].id() == b.inputs()[1].id());
 }

@@ -27,6 +37,44 @@ TEST_CASE("test no simplify") {
   auto a = array({1.0f, 2.0f});
   auto b = cos(a) + sin(a);
   simplify(b);
-  eval(b);
   CHECK(b.inputs()[0].id() != b.inputs()[1].id());
 }
+
+TEST_CASE("test simplify multi output") {
+  {
+    auto a = array(1.0);
+    auto b = array(2.0);
+    auto c = divmod(a, b);
+    auto d = divmod(a, b);
+    auto e = c[0] + d[0];
+    auto f = c[1] + d[1];
+
+    simplify({e, f});
+    CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
+    CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
+  }
+
+  {
+    auto a = array(1.0);
+    auto b = array(1.0);
+    auto c = divmod(a, b);
+    simplify(c);
+    CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
+    CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
+    CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
+  }
+
+  // Make sure the output order of multi-output primitives
+  // is respected in simplification
+  {
+    auto a = array(1.0);
+    auto b = array(2.0);
+    auto c = divmod(a, b);
+    auto d = divmod(a, b);
+    auto e = stack({c[0], c[1], d[0], d[1]});
+    simplify(e);
+    CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
+    CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
+    CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
+  }
+}
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 1d8ee43d9..1bf02c243 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -30,7 +30,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
   CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
   CHECK_EQ(
       norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
-  CHECK(array_equal(
+  CHECK(allclose(
       norm(x, 0, false),
       array(
          {std::sqrt(0 + 3 * 3 + 6 * 6),
           std::sqrt(1 + 4 * 4 + 7 * 7),
           std::sqrt(2 + 5 * 5 + 8 * 8)}))
          .item<bool>());
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 0521d9c25..70b4c82e4 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -2397,3 +2397,58 @@ TEST_CASE("inner") {
   expected = array({7., 0., 0., 7.}, {2, 2});
   CHECK(array_equal(z, expected).item<bool>());
 }
+
+TEST_CASE("test divmod") {
+  auto x = array({1, 2, 3});
+  auto y = array({1, 1, 1});
+  auto out = divmod(x, y);
+  CHECK(array_equal(out[0], array({1, 2, 3})).item<bool>());
+  CHECK(array_equal(out[1], array({0, 0, 0})).item<bool>());
+
+  x = array({5, 6, 7});
+  y = array({2, 2, 2});
+  out = divmod(x, y);
+  CHECK(array_equal(out[0], array({2, 3, 3})).item<bool>());
+  CHECK(array_equal(out[1], array({1, 0, 1})).item<bool>());
+
+  // Siblings should be gone after evaling the graph
+  CHECK(out[0].siblings().empty());
+  CHECK(out[1].siblings().empty());
+
+  x = array({5.0, 6.0, 7.0});
+  y = array({2.0, 2.0, 2.0});
+  out = divmod(x, y);
+  CHECK(array_equal(out[0], array({2.0, 3.0, 3.0})).item<bool>());
+  CHECK(array_equal(out[1], array({1.0, 0.0, 1.0})).item<bool>());
+
+  x = array({1.0}, complex64);
+  y = array({2.0}, complex64);
+  CHECK_THROWS(divmod(x, y));
+
+  // Check that we can eval on both outputs
+  x = array({1.0});
+  y = array({2.0});
+  out = divmod(x, y);
+  eval(out);
+  CHECK_EQ(out[0].item<float>(), 0.0);
+  CHECK_EQ(out[1].item<float>(), 1.0);
+
+  // Check nested in the graph
+  x = array({1.0});
+  y = array({2.0});
+  out = divmod(x, y);
+  auto z = out[0] + out[1];
+  CHECK_EQ(z.item<float>(), 1.0);
+
+  // Check that we can still eval when one output goes out of scope
+  std::vector<array> out_holder;
+  { out_holder.push_back(divmod(x, y)[0]); }
+  eval(out_holder);
+  CHECK_EQ(out_holder[0].item<float>(), 0.0);
+
+  // Check that we can still eval when the other output goes out of scope
+  out_holder.clear();
+  { out_holder.push_back(divmod(x, y)[1]); }
+  eval(out_holder);
+  CHECK_EQ(out_holder[0].item<float>(), 1.0);
+}
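
A short usage sketch of the new op, for reference; it assumes only the `mx.divmod` binding added above and the `(a // b, a % b)` equivalence stated in its docstring:

    import mlx.core as mx

    a = mx.array([7.0, 8.0, 9.0])
    b = mx.array([2.0, 2.0, 2.0])

    # Fused: a single primitive produces both outputs, which remain
    # "siblings" sharing one graph node until they are evaluated.
    q, r = mx.divmod(a, b)  # q -> [3, 4, 4], r -> [1, 0, 1]

    # Unfused equivalent, per the docstring: two separate primitives.
    q2, r2 = a // b, a % b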