MLX
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <unordered_set>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/io/load.h"
10#include "mlx/stream.h"
11
// Expands to the declaration of a vmap() override taking the inputs and their
// axes and returning a pair of (arrays, axes). Intended for use inside the
// body of a Primitive subclass.
// NOTE(review): `virtual` is redundant next to `override`.
12#define DEFINE_VMAP() \
13 virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
14 const std::vector<array>& inputs, const std::vector<int>& axes) \
15 override;
16
// Expands to declarations of both jvp() and vjp() overrides with the same
// parameter lists as the virtual functions declared on Primitive.
17#define DEFINE_GRADS() \
18 std::vector<array> jvp( \
19 const std::vector<array>& primals, \
20 const std::vector<array>& tangents, \
21 const std::vector<int>& argnums) override; \
22 \
23 std::vector<array> vjp( \
24 const std::vector<array>& primals, \
25 const std::vector<array>& cotangents, \
26 const std::vector<int>& argnums, \
27 const std::vector<array>& outputs) override;
28
// Expands to an inline print() override that writes the stringized macro
// argument (#PRIMITIVE) to the given stream.
29#define DEFINE_PRINT(PRIMITIVE) \
30 void print(std::ostream& os) override { \
31 os << #PRIMITIVE; \
32 }
33
// Expands to an is_equivalent() override that always returns true.
// NOTE(review): `other` is never inspected, so this presumably relies on the
// caller having already matched the primitive types — confirm at call sites.
34#define DEFINE_DEFAULT_IS_EQUIVALENT() \
35 bool is_equivalent(const Primitive& other) const override { \
36 return true; \
37 }
38
// Expands to an output_shapes() override that reports a single output whose
// shape equals the shape of the first input.
// NOTE(review): inputs[0] is read without checking that `inputs` is non-empty,
// and the `;` after the closing brace is a stray empty declaration.
39#define DEFINE_INPUT_OUTPUT_SHAPE() \
40 std::vector<std::vector<int>> output_shapes( \
41 const std::vector<array>& inputs) override { \
42 return {inputs[0].shape()}; \
43 };
44
45namespace mlx::core {
46
47// Abstract base class
// Base class for all primitives: stores the Stream it was constructed with and
// declares the virtual evaluation / transformation interface. eval_cpu,
// eval_gpu and print are pure virtual, so the class cannot be instantiated.
// NOTE(review): the embedded listing numbers are not contiguous (e.g. 51 -> 53,
// 61 -> 69); the original per-member doc comments appear to have been dropped
// by the page extraction.
48class Primitive {
49 public:
// Stores the stream by value.
50 explicit Primitive(Stream stream) : stream_(stream) {}
51
// Returns the `device` member of the stored stream.
// NOTE(review): does not modify the object but is not const-qualified.
53 const Device& device() {
54 return stream().device;
55 }
56
// Returns a reference to the stored stream.
// NOTE(review): does not modify the object but is not const-qualified.
58 const Stream& stream() {
59 return stream_;
60 }
61
// Backend evaluation hooks; every concrete primitive must provide both.
// `outputs` is passed by non-const reference.
69 virtual void eval_cpu(
70 const std::vector<array>& inputs,
71 std::vector<array>& outputs) = 0;
72 virtual void eval_gpu(
73 const std::vector<array>& inputs,
74 std::vector<array>& outputs) = 0;
75
// Virtual but not pure; the base definition is not visible in this header.
// (Presumably the Jacobian-vector product — name-based, confirm.)
79 virtual std::vector<array> jvp(
80 const std::vector<array>& primals,
81 const std::vector<array>& tangents,
82 const std::vector<int>& argnums);
83
// Virtual but not pure; the base definition is not visible in this header.
// (Presumably the vector-Jacobian product — name-based, confirm.)
87 virtual std::vector<array> vjp(
88 const std::vector<array>& primals,
89 const std::vector<array>& cotangents,
90 const std::vector<int>& argnums,
91 const std::vector<array>& outputs);
92
// Takes the inputs and one axis per input; returns (arrays, axes).
// The base definition is not visible in this header.
99 virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100 const std::vector<array>& inputs,
101 const std::vector<int>& axes);
102
// Writes a description of the primitive to `os`; must be overridden.
104 virtual void print(std::ostream& os) = 0;
105
// Default: a primitive is never equivalent to another. `other` is unused.
107 virtual bool is_equivalent(const Primitive& other) const {
108 return false;
109 }
110
// Returns one shape per output; the base definition is not visible here.
113 virtual std::vector<std::vector<int>> output_shapes(
114 const std::vector<array>& inputs);
115
// Virtual destructor; copy and move construction/assignment are deleted.
116 virtual ~Primitive() = default;
117 Primitive(const Primitive& other) = delete;
118 Primitive(Primitive&& other) = delete;
119 Primitive& operator=(const Primitive& other) = delete;
120 Primitive& operator=(Primitive&& other) = delete;
121
122 private:
123 // Every primitive stores the stream it should run in
124 Stream stream_;
125};
126
// Primitive with a single output. Subclasses implement the `array& output`
// overloads; the vector-based overrides below forward to them with outputs[0].
// NOTE(review): embedded listing numbers jump (127 -> 131, 131 -> 133,
// 149 -> 151, 151 -> 153), so the constructor and some special-member
// declarations are probably missing from this listing — confirm upstream.
127class UnaryPrimitive : public Primitive {
131 public:
133
// Single-output evaluation hooks that concrete subclasses must implement.
134 virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
135 virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
136
// Adapters from the multi-output interface: only outputs[0] is used.
// NOTE(review): no check that `outputs` is non-empty.
137 inline void eval_cpu(
138 const std::vector<array>& inputs,
139 std::vector<array>& outputs) override {
140 eval_cpu(inputs, outputs[0]);
141 }
142 inline void eval_gpu(
143 const std::vector<array>& inputs,
144 std::vector<array>& outputs) override {
145 eval_gpu(inputs, outputs[0]);
146 }
147
// Copy construction and copy assignment are deleted.
148 virtual ~UnaryPrimitive() = default;
149 UnaryPrimitive(const UnaryPrimitive& other) = delete;
151 UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
153};
154
// Abs: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (156 -> 158, 161 -> 167), so the
// constructor and other declarations were likely dropped by extraction.
155class Abs : public UnaryPrimitive {
156 public:
158
159 void eval_cpu(const std::vector<array>& inputs, array& out) override;
160 void eval_gpu(const std::vector<array>& inputs, array& out) override;
161
167
168 private:
169 void eval(const std::vector<array>& inputs, array& out);
170};
171
// Add: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (173 -> 175, 178 -> 184), so the
// constructor and other declarations were likely dropped by extraction.
172class Add : public UnaryPrimitive {
173 public:
175
176 void eval_cpu(const std::vector<array>& inputs, array& out) override;
177 void eval_gpu(const std::vector<array>& inputs, array& out) override;
178
184
185 private:
186 void eval(const std::vector<array>& inputs, array& out);
187};
188
// AddMM: single-output primitive parameterized by two float scalars, alpha and
// beta, stored as const members. Declares eval, vjp and is_equivalent
// overrides; how alpha/beta are applied is not visible in this header.
// NOTE(review): embedded listing numbers jump 202 -> 205 (dropped lines).
189class AddMM : public UnaryPrimitive {
190 public:
191 explicit AddMM(Stream stream, float alpha, float beta)
192 : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {};
193
194 void eval_cpu(const std::vector<array>& inputs, array& out) override;
195 void eval_gpu(const std::vector<array>& inputs, array& out) override;
196
197 std::vector<array> vjp(
198 const std::vector<array>& primals,
199 const std::vector<array>& cotangents,
200 const std::vector<int>& argnums,
201 const std::vector<array>& outputs) override;
202
205
206 bool is_equivalent(const Primitive& other) const override;
207
208 private:
209 const float alpha_;
210 const float beta_;
211};
212
// Arange: single-output primitive storing start, stop and step as doubles.
// Declares eval and is_equivalent overrides plus a private eval() helper.
// NOTE(review): embedded listing numbers jump 220 -> 222 (dropped line).
213class Arange : public UnaryPrimitive {
214 public:
215 explicit Arange(Stream stream, double start, double stop, double step)
216 : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {};
217
218 void eval_cpu(const std::vector<array>& inputs, array& out) override;
219 void eval_gpu(const std::vector<array>& inputs, array& out) override;
220
222 bool is_equivalent(const Primitive& other) const override;
223
224 private:
225 double start_;
226 double stop_;
227 double step_;
228
229 void eval(const std::vector<array>& inputs, array& out);
230};
231
// ArcCos: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (233 -> 235, 238 -> 244), so the
// constructor and other declarations were likely dropped by extraction.
232class ArcCos : public UnaryPrimitive {
233 public:
235
236 void eval_cpu(const std::vector<array>& inputs, array& out) override;
237 void eval_gpu(const std::vector<array>& inputs, array& out) override;
238
244
245 private:
246 void eval(const std::vector<array>& inputs, array& out);
247};
248
// ArcCosh: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (250 -> 252, 255 -> 261), so the
// constructor and other declarations were likely dropped by extraction.
249class ArcCosh : public UnaryPrimitive {
250 public:
252
253 void eval_cpu(const std::vector<array>& inputs, array& out) override;
254 void eval_gpu(const std::vector<array>& inputs, array& out) override;
255
261
262 private:
263 void eval(const std::vector<array>& inputs, array& out);
264};
265
// ArcSin: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (267 -> 269, 272 -> 278), so the
// constructor and other declarations were likely dropped by extraction.
266class ArcSin : public UnaryPrimitive {
267 public:
269
270 void eval_cpu(const std::vector<array>& inputs, array& out) override;
271 void eval_gpu(const std::vector<array>& inputs, array& out) override;
272
278
279 private:
280 void eval(const std::vector<array>& inputs, array& out);
281};
282
// ArcSinh: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (284 -> 286, 289 -> 295), so the
// constructor and other declarations were likely dropped by extraction.
283class ArcSinh : public UnaryPrimitive {
284 public:
286
287 void eval_cpu(const std::vector<array>& inputs, array& out) override;
288 void eval_gpu(const std::vector<array>& inputs, array& out) override;
289
295
296 private:
297 void eval(const std::vector<array>& inputs, array& out);
298};
299
// ArcTan: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (301 -> 303, 306 -> 312), so the
// constructor and other declarations were likely dropped by extraction.
300class ArcTan : public UnaryPrimitive {
301 public:
303
304 void eval_cpu(const std::vector<array>& inputs, array& out) override;
305 void eval_gpu(const std::vector<array>& inputs, array& out) override;
306
312
313 private:
314 void eval(const std::vector<array>& inputs, array& out);
315};
316
// ArcTan2: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (318 -> 320, 323 -> 329), so the
// constructor and other declarations were likely dropped by extraction.
317class ArcTan2 : public UnaryPrimitive {
318 public:
320
321 void eval_cpu(const std::vector<array>& inputs, array& out) override;
322 void eval_gpu(const std::vector<array>& inputs, array& out) override;
323
329
330 private:
331 void eval(const std::vector<array>& inputs, array& out);
332};
333
// ArcTanh: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (335 -> 337, 340 -> 346), so the
// constructor and other declarations were likely dropped by extraction.
334class ArcTanh : public UnaryPrimitive {
335 public:
337
338 void eval_cpu(const std::vector<array>& inputs, array& out) override;
339 void eval_gpu(const std::vector<array>& inputs, array& out) override;
340
346
347 private:
348 void eval(const std::vector<array>& inputs, array& out);
349};
350
// Body of the class whose constructor is named ArgPartition; it initializes a
// UnaryPrimitive base and stores two ints, kth and axis.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 350 -> 352, and again 358 -> 362) — restore from upstream.
352 public:
353 explicit ArgPartition(Stream stream, int kth, int axis)
354 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {};
355
356 void eval_cpu(const std::vector<array>& inputs, array& out) override;
357 void eval_gpu(const std::vector<array>& inputs, array& out) override;
358
362 bool is_equivalent(const Primitive& other) const override;
363
364 private:
365 int kth_;
366 int axis_;
367
368 void eval(const std::vector<array>& inputs, array& out);
369};
370
// ArgReduce: single-output primitive storing a ReduceType and an int axis.
// Declares eval, is_equivalent and output_shapes overrides.
// NOTE(review): embedded listing numbers jump 372 -> 377, and ReduceType is
// used without a visible declaration — presumably a nested enum on the dropped
// lines; confirm upstream.
371class ArgReduce : public UnaryPrimitive {
372 public:
377
378 explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
379 : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {};
380
381 void eval_cpu(const std::vector<array>& inputs, array& out) override;
382 void eval_gpu(const std::vector<array>& inputs, array& out) override;
383
386 bool is_equivalent(const Primitive& other) const override;
387 std::vector<std::vector<int>> output_shapes(
388 const std::vector<array>& inputs) override;
389
390 private:
391 ReduceType reduce_type_;
392 int axis_;
393
394 void eval(const std::vector<array>& inputs, array& out);
395};
396
// ArgSort: single-output primitive storing an int axis. Declares eval and
// is_equivalent overrides plus a private eval() helper.
// NOTE(review): embedded listing numbers jump 404 -> 408 (dropped lines).
397class ArgSort : public UnaryPrimitive {
398 public:
399 explicit ArgSort(Stream stream, int axis)
400 : UnaryPrimitive(stream), axis_(axis) {};
401
402 void eval_cpu(const std::vector<array>& inputs, array& out) override;
403 void eval_gpu(const std::vector<array>& inputs, array& out) override;
404
408 bool is_equivalent(const Primitive& other) const override;
409
410 private:
411 int axis_;
412
413 void eval(const std::vector<array>& inputs, array& out);
414};
415
// AsType: single-output primitive storing a Dtype (presumably the target type
// of a cast — name-based, confirm). Declares eval and is_equivalent overrides.
// NOTE(review): embedded listing numbers jump 423 -> 428 (dropped lines).
416class AsType : public UnaryPrimitive {
417 public:
418 explicit AsType(Stream stream, Dtype dtype)
419 : UnaryPrimitive(stream), dtype_(dtype) {};
420
421 void eval_cpu(const std::vector<array>& inputs, array& out) override;
422 void eval_gpu(const std::vector<array>& inputs, array& out) override;
423
428 bool is_equivalent(const Primitive& other) const override;
429
430 private:
431 Dtype dtype_;
432
433 void eval(const std::vector<array>& inputs, array& out);
434};
435
// AsStrided: single-output primitive storing a shape, strides and an offset.
// shape and strides are taken by value and moved into the members.
// NOTE(review): the constructor is incomplete in this listing — embedded
// numbers jump 438 -> 440 and 442 -> 444, so the first parameter and the start
// of the initializer list (presumably `Stream stream` and the base
// initializer) were dropped by extraction; restore from upstream.
436class AsStrided : public UnaryPrimitive {
437 public:
438 explicit AsStrided(
440 std::vector<int> shape,
441 std::vector<size_t> strides,
442 size_t offset)
444 shape_(std::move(shape)),
445 strides_(std::move(strides)),
446 offset_(offset) {};
447
448 void eval_cpu(const std::vector<array>& inputs, array& out) override;
449 void eval_gpu(const std::vector<array>& inputs, array& out) override;
450
453 bool is_equivalent(const Primitive& other) const override;
454
455 private:
456 std::vector<int> shape_;
457 std::vector<size_t> strides_;
458 size_t offset_;
459
460 void eval(const std::vector<array>& inputs, array& out);
461};
462
// Body of a UnaryPrimitive-derived class that stores one value of the nested
// enum Op (And, Or, Xor, LeftShift, RightShift) and declares eval,
// is_equivalent and print overrides.
// NOTE(review): both the class header and the constructor signature line are
// absent from this listing (embedded numbers jump 462 -> 464 and 466 -> 468),
// so the class name cannot be read here — restore from upstream.
464 public:
465 enum Op { And, Or, Xor, LeftShift, RightShift };
466
468 : UnaryPrimitive(stream), op_(op) {};
469
470 void eval_cpu(const std::vector<array>& inputs, array& out) override;
471 void eval_gpu(const std::vector<array>& inputs, array& out) override;
472
474 bool is_equivalent(const Primitive& other) const override;
475 void print(std::ostream& os) override;
477
478 private:
479 Op op_;
480};
481
// Body of the class whose constructor is named BlockMaskedMM; it initializes a
// UnaryPrimitive base and stores an int block_size. Declares eval, vjp and
// is_equivalent overrides.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 481 -> 483, and again 495 -> 497) — restore from upstream.
483 public:
484 explicit BlockMaskedMM(Stream stream, int block_size)
485 : UnaryPrimitive(stream), block_size_(block_size) {};
486
487 void eval_cpu(const std::vector<array>& inputs, array& out) override;
488 void eval_gpu(const std::vector<array>& inputs, array& out) override;
489
490 std::vector<array> vjp(
491 const std::vector<array>& primals,
492 const std::vector<array>& cotangents,
493 const std::vector<int>& argnums,
494 const std::vector<array>& outputs) override;
495
497 bool is_equivalent(const Primitive& other) const override;
498
499 private:
500 int block_size_;
501
502 void eval(const std::vector<array>& inputs, array& out);
503};
504
// GatherMM: single-output primitive. Declares eval and vjp overrides and a
// private eval() helper; no data members are visible.
// NOTE(review): embedded listing numbers jump (506 -> 508, 517 -> 520), so the
// constructor and other declarations were likely dropped by extraction.
505class GatherMM : public UnaryPrimitive {
506 public:
508
509 void eval_cpu(const std::vector<array>& inputs, array& out) override;
510 void eval_gpu(const std::vector<array>& inputs, array& out) override;
511
512 std::vector<array> vjp(
513 const std::vector<array>& primals,
514 const std::vector<array>& cotangents,
515 const std::vector<int>& argnums,
516 const std::vector<array>& outputs) override;
517
520
521 private:
522 void eval(const std::vector<array>& inputs, array& out);
523};
524
// Broadcast: single-output primitive storing a copy of the shape vector passed
// to the constructor (presumably the target shape — name-based, confirm).
// NOTE(review): embedded listing numbers jump 532 -> 536 (dropped lines).
525class Broadcast : public UnaryPrimitive {
526 public:
527 explicit Broadcast(Stream stream, const std::vector<int>& shape)
528 : UnaryPrimitive(stream), shape_(shape) {};
529
530 void eval_cpu(const std::vector<array>& inputs, array& out) override;
531 void eval_gpu(const std::vector<array>& inputs, array& out) override;
532
536 bool is_equivalent(const Primitive& other) const override;
537
538 private:
539 std::vector<int> shape_;
540
541 void eval(const std::vector<array>& inputs, array& out);
542};
543
// Ceil: single-output primitive. Overrides the CPU/GPU evaluation hooks and
// declares a private eval() helper; the operation itself is defined elsewhere.
// NOTE(review): embedded listing numbers jump (545 -> 547, 550 -> 556), so the
// constructor and other declarations were likely dropped by extraction.
544class Ceil : public UnaryPrimitive {
545 public:
547
548 void eval_cpu(const std::vector<array>& inputs, array& out) override;
549 void eval_gpu(const std::vector<array>& inputs, array& out) override;
550
556
557 private:
558 void eval(const std::vector<array>& inputs, array& out);
559};
560
// Compiled: multi-output primitive (derives directly from Primitive) that
// holds const copies of an input list, an output list, a tape and a set of
// constant ids, plus a kernel library name string.
// NOTE(review): embedded listing numbers jump 572 -> 574 and 583 -> 586; the
// first constructor parameter (presumably `Stream stream`) is missing from
// this listing — restore from upstream.
561class Compiled : public Primitive {
562 public:
563 /*
564 * The inputs, outputs and tape are either tracers or constants.
565 * - The tape should not contain the inputs, but it should contain the
566 * outputs.
567 * - The tape should also have only one array per primitive for multi-output
568 * primitives.
569 * - The constant_ids contains ids of arrays in the input list that are safe
570 * to treat as scalar constants.
571 */
// Declared here; the constructor definition is not in this header.
572 explicit Compiled(
574 std::vector<array> inputs,
575 std::vector<array> outputs,
576 std::vector<array> tape,
577 std::unordered_set<uintptr_t> constant_ids);
578
579 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
580 override;
581 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
582 override;
583
586 std::vector<std::vector<int>> output_shapes(
587 const std::vector<array>& inputs) override;
588 void print(std::ostream& os) override;
589 bool is_equivalent(const Primitive& other) const override;
590
// Returns a copy of the stored kernel library name.
// NOTE(review): returning by value copies the string on every call.
591 std::string lib_name() const {
592 return kernel_lib_;
593 }
594
595 private:
596 const std::vector<array> inputs_;
597 const std::vector<array> outputs_;
598 const std::vector<array> tape_;
599 const std::unordered_set<uintptr_t> constant_ids_;
600
601 std::string kernel_lib_;
602};
603
// Body of the class whose constructor is named Concatenate; it initializes a
// UnaryPrimitive base and stores an int axis.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 603 -> 605, and again 611 -> 615) — restore from upstream.
605 public:
606 explicit Concatenate(Stream stream, int axis)
607 : UnaryPrimitive(stream), axis_(axis) {};
608
609 void eval_cpu(const std::vector<array>& inputs, array& out) override;
610 void eval_gpu(const std::vector<array>& inputs, array& out) override;
611
615 bool is_equivalent(const Primitive& other) const override;
616
617 private:
618 int axis_;
619
620 void eval(const std::vector<array>& inputs, array& out);
621};
622
// Conjugate: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 629 -> 634 (dropped lines).
623class Conjugate : public UnaryPrimitive {
624 public:
625 explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {};
626
627 void eval_cpu(const std::vector<array>& inputs, array& out) override;
628 void eval_gpu(const std::vector<array>& inputs, array& out) override;
629
634
635 private:
636 void eval(const std::vector<array>& inputs, array& out);
637};
638
// Body of the class whose constructor is named Convolution; it initializes a
// UnaryPrimitive base and stores copies of the kernel strides, padding, kernel
// dilation and input dilation vectors, plus `groups` (default 1) and `flip`
// (default false). Declares eval, vjp and is_equivalent overrides.
// The initializer list is written in member-declaration order (padding_
// first), which differs from the parameter order.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 638 -> 640, and again 665 -> 667) — restore from upstream.
640 public:
641 explicit Convolution(
642 Stream stream,
643 const std::vector<int>& kernel_strides,
644 const std::vector<int>& padding,
645 const std::vector<int>& kernel_dilation,
646 const std::vector<int>& input_dilation,
647 const int groups = 1,
648 const bool flip = false)
649 : UnaryPrimitive(stream),
650 padding_(padding),
651 kernel_strides_(kernel_strides),
652 kernel_dilation_(kernel_dilation),
653 input_dilation_(input_dilation),
654 groups_(groups),
655 flip_(flip) {};
656
657 void eval_cpu(const std::vector<array>& inputs, array& out) override;
658 void eval_gpu(const std::vector<array>& inputs, array& out) override;
659
660 std::vector<array> vjp(
661 const std::vector<array>& primals,
662 const std::vector<array>& cotangents,
663 const std::vector<int>& argnums,
664 const std::vector<array>& outputs) override;
665
667 bool is_equivalent(const Primitive& other) const override;
668
669 private:
670 std::vector<int> padding_;
671 std::vector<int> kernel_strides_;
672 std::vector<int> kernel_dilation_;
673 std::vector<int> input_dilation_;
674 int groups_;
675 bool flip_;
676
677 void eval(const std::vector<array>& inputs, array& out);
678};
679
// Copy: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 686 -> 692 (dropped lines).
680class Copy : public UnaryPrimitive {
681 public:
682 explicit Copy(Stream stream) : UnaryPrimitive(stream) {};
683
684 void eval_cpu(const std::vector<array>& inputs, array& out) override;
685 void eval_gpu(const std::vector<array>& inputs, array& out) override;
686
692
693 private:
694 void eval(const std::vector<array>& inputs, array& out);
695};
696
// Cos: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 703 -> 709 (dropped lines).
697class Cos : public UnaryPrimitive {
698 public:
699 explicit Cos(Stream stream) : UnaryPrimitive(stream) {};
700
701 void eval_cpu(const std::vector<array>& inputs, array& out) override;
702 void eval_gpu(const std::vector<array>& inputs, array& out) override;
703
709
710 private:
711 void eval(const std::vector<array>& inputs, array& out);
712};
713
// Cosh: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 720 -> 726 (dropped lines).
714class Cosh : public UnaryPrimitive {
715 public:
716 explicit Cosh(Stream stream) : UnaryPrimitive(stream) {};
717
718 void eval_cpu(const std::vector<array>& inputs, array& out) override;
719 void eval_gpu(const std::vector<array>& inputs, array& out) override;
720
726
727 private:
728 void eval(const std::vector<array>& inputs, array& out);
729};
730
// CustomVJP: multi-output primitive (derives directly from Primitive) that
// stores a caller-supplied callable taking three array vectors and returning
// an array vector. The callable is moved into vjp_fun_; how vjp() uses it is
// defined outside this header.
// NOTE(review): std::function is used but <functional> is not among the
// includes visible at the top of this file — presumably reached transitively;
// consider including it directly. Embedded listing numbers jump 751 -> 753.
731class CustomVJP : public Primitive {
732 public:
733 explicit CustomVJP(
734 Stream stream,
735 std::function<std::vector<array>(
736 const std::vector<array>&,
737 const std::vector<array>&,
738 const std::vector<array>&)> fun)
739 : Primitive(stream), vjp_fun_(std::move(fun)) {}
740
741 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
742 override;
743 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
744 override;
745
746 std::vector<array> vjp(
747 const std::vector<array>& primals,
748 const std::vector<array>& cotan,
749 const std::vector<int>& argnums,
750 const std::vector<array>& outputs) override;
751
753
754 private:
755 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
756
757 std::function<std::vector<array>(
758 const std::vector<array>&,
759 const std::vector<array>&,
760 const std::vector<array>&)>
761 vjp_fun_;
762};
763
// Depends: multi-output primitive (derives directly from Primitive) with no
// state beyond the stream. Declares eval and vjp overrides and a private
// eval() helper; semantics are defined outside this header.
// NOTE(review): embedded listing numbers jump 778 -> 780 (dropped line).
764class Depends : public Primitive {
765 public:
766 explicit Depends(Stream stream) : Primitive(stream) {}
767
768 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
769 override;
770 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
771 override;
772
773 std::vector<array> vjp(
774 const std::vector<array>& primals,
775 const std::vector<array>& cotan,
776 const std::vector<int>& argnums,
777 const std::vector<array>& outputs) override;
778
780
781 private:
782 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
783};
784
// Divide: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 791 -> 797 (dropped lines).
785class Divide : public UnaryPrimitive {
786 public:
787 explicit Divide(Stream stream) : UnaryPrimitive(stream) {};
788
789 void eval_cpu(const std::vector<array>& inputs, array& out) override;
790 void eval_gpu(const std::vector<array>& inputs, array& out) override;
791
797
798 private:
799 void eval(const std::vector<array>& inputs, array& out);
800};
801
// DivMod: multi-output primitive (derives directly from Primitive) with no
// state beyond the stream.
// NOTE(review): embedded listing numbers jump 810 -> 815 (dropped lines).
802class DivMod : public Primitive {
803 public:
804 explicit DivMod(Stream stream) : Primitive(stream) {};
805
806 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
807 override;
808 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
809 override;
810
// Reports two outputs, each with the shape of the first input.
// NOTE(review): inputs[0] is read without checking that `inputs` is non-empty.
815 std::vector<std::vector<int>> output_shapes(
816 const std::vector<array>& inputs) override {
817 return std::vector{inputs[0].shape(), inputs[0].shape()};
818 };
819
820 private:
821 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
822};
823
// Select: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 830 -> 836 (dropped lines).
824class Select : public UnaryPrimitive {
825 public:
826 explicit Select(Stream stream) : UnaryPrimitive(stream) {};
827
828 void eval_cpu(const std::vector<array>& inputs, array& out) override;
829 void eval_gpu(const std::vector<array>& inputs, array& out) override;
830
836
837 private:
838 void eval(const std::vector<array>& inputs, array& out);
839};
840
// Remainder: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 847 -> 853 (dropped lines).
841class Remainder : public UnaryPrimitive {
842 public:
843 explicit Remainder(Stream stream) : UnaryPrimitive(stream) {};
844
845 void eval_cpu(const std::vector<array>& inputs, array& out) override;
846 void eval_gpu(const std::vector<array>& inputs, array& out) override;
847
853
854 private:
855 void eval(const std::vector<array>& inputs, array& out);
856};
857
// Equal: single-output primitive storing an equal_nan flag (default false).
// How the flag affects evaluation is defined outside this header.
// NOTE(review): embedded listing numbers jump 865 -> 870 (dropped lines).
858class Equal : public UnaryPrimitive {
859 public:
860 explicit Equal(Stream stream, bool equal_nan = false)
861 : UnaryPrimitive(stream), equal_nan_(equal_nan) {};
862
863 void eval_cpu(const std::vector<array>& inputs, array& out) override;
864 void eval_gpu(const std::vector<array>& inputs, array& out) override;
865
870
// Writes "NaNEqual" when the flag is set, otherwise "Equal".
871 void print(std::ostream& os) override {
872 if (equal_nan_) {
873 os << "NaNEqual";
874 } else {
875 os << "Equal";
876 }
877 }
878
879 private:
880 void eval(const std::vector<array>& inputs, array& out);
881 bool equal_nan_;
882};
883
// Erf: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 890 -> 896 (dropped lines).
884class Erf : public UnaryPrimitive {
885 public:
886 explicit Erf(Stream stream) : UnaryPrimitive(stream) {};
887
888 void eval_cpu(const std::vector<array>& inputs, array& out) override;
889 void eval_gpu(const std::vector<array>& inputs, array& out) override;
890
896
897 private:
898 void eval(const std::vector<array>& inputs, array& out);
899};
900
// ErfInv: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 907 -> 913 (dropped lines).
901class ErfInv : public UnaryPrimitive {
902 public:
903 explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {};
904
905 void eval_cpu(const std::vector<array>& inputs, array& out) override;
906 void eval_gpu(const std::vector<array>& inputs, array& out) override;
907
913
914 private:
915 void eval(const std::vector<array>& inputs, array& out);
916};
917
// Exp: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 924 -> 930 (dropped lines).
918class Exp : public UnaryPrimitive {
919 public:
920 explicit Exp(Stream stream) : UnaryPrimitive(stream) {};
921
922 void eval_cpu(const std::vector<array>& inputs, array& out) override;
923 void eval_gpu(const std::vector<array>& inputs, array& out) override;
924
930
931 private:
932 void eval(const std::vector<array>& inputs, array& out);
933};
934
// Expm1: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 941 -> 946 (dropped lines).
935class Expm1 : public UnaryPrimitive {
936 public:
937 explicit Expm1(Stream stream) : UnaryPrimitive(stream) {};
938
939 void eval_cpu(const std::vector<array>& inputs, array& out) override;
940 void eval_gpu(const std::vector<array>& inputs, array& out) override;
941
946
947 private:
948 void eval(const std::vector<array>& inputs, array& out);
949};
950
// FFT: single-output primitive storing a copy of the axes vector and two
// flags, inverse and real. How the flags select the transform is defined
// outside this header.
// NOTE(review): embedded listing numbers jump 962 -> 966 (dropped lines).
951class FFT : public UnaryPrimitive {
952 public:
953 explicit FFT(
954 Stream stream,
955 const std::vector<size_t>& axes,
956 bool inverse,
957 bool real)
958 : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {};
959
960 void eval_cpu(const std::vector<array>& inputs, array& out) override;
961 void eval_gpu(const std::vector<array>& inputs, array& out) override;
962
966
967 bool is_equivalent(const Primitive& other) const override;
968
969 private:
970 std::vector<size_t> axes_;
971 bool inverse_;
972 bool real_;
973
974 void eval(const std::vector<array>& inputs, array& out);
975};
976
// Floor: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 983 -> 989 (dropped lines).
977class Floor : public UnaryPrimitive {
978 public:
979 explicit Floor(Stream stream) : UnaryPrimitive(stream) {};
980
981 void eval_cpu(const std::vector<array>& inputs, array& out) override;
982 void eval_gpu(const std::vector<array>& inputs, array& out) override;
983
989
990 private:
991 void eval(const std::vector<array>& inputs, array& out);
992};
993
// Full: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1000 -> 1005 (dropped lines).
994class Full : public UnaryPrimitive {
995 public:
996 explicit Full(Stream stream) : UnaryPrimitive(stream) {};
997
998 void eval_cpu(const std::vector<array>& inputs, array& out) override;
999 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1000
1005
1006 private:
1007 void eval(const std::vector<array>& inputs, array& out);
1008};
1009
// Gather: single-output primitive storing copies of an axes vector and a
// slice_sizes vector. Declares eval and is_equivalent overrides.
// NOTE(review): embedded listing numbers jump 1020 -> 1024 (dropped lines).
1010class Gather : public UnaryPrimitive {
1011 public:
1012 explicit Gather(
1013 Stream stream,
1014 const std::vector<int>& axes,
1015 const std::vector<int>& slice_sizes)
1016 : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {};
1017
1018 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1019 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1020
1024 bool is_equivalent(const Primitive& other) const override;
1025
1026 private:
1027 void eval(const std::vector<array>& inputs, array& out);
1028 std::vector<int> axes_;
1029 std::vector<int> slice_sizes_;
1030};
1031
// Greater: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1038 -> 1044 (dropped lines).
1032class Greater : public UnaryPrimitive {
1033 public:
1034 explicit Greater(Stream stream) : UnaryPrimitive(stream) {};
1035
1036 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1037 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1038
1044
1045 private:
1046 void eval(const std::vector<array>& inputs, array& out);
1047};
1048
// Body of the class whose constructor is named GreaterEqual; it initializes a
// UnaryPrimitive base and holds no other state.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 1048 -> 1050, and again 1055 -> 1061) — restore from upstream.
1050 public:
1051 explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {};
1052
1053 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1054 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1055
1061
1062 private:
1063 void eval(const std::vector<array>& inputs, array& out);
1064};
1065
// Less: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1072 -> 1078 (dropped lines).
1066class Less : public UnaryPrimitive {
1067 public:
1068 explicit Less(Stream stream) : UnaryPrimitive(stream) {};
1069
1070 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1071 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1072
1078
1079 private:
1080 void eval(const std::vector<array>& inputs, array& out);
1081};
1082
// Body of the class whose constructor is named LessEqual; it initializes a
// UnaryPrimitive base and holds no other state.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 1082 -> 1084, and again 1089 -> 1095) — restore from upstream.
1084 public:
1085 explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {};
1086
1087 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1088 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1089
1095
1096 private:
1097 void eval(const std::vector<array>& inputs, array& out);
1098};
1099
// Load: single-output primitive holding a shared io::Reader, a size_t offset
// and a swap_endianness flag (default false). How they are used during
// evaluation is defined outside this header.
// NOTE(review): `reader` is taken by value and then copied into reader_;
// std::move(reader) would avoid an extra shared_ptr reference-count update.
// Embedded listing numbers jump 1114 -> 1116 (dropped line).
1100class Load : public UnaryPrimitive {
1101 public:
1102 explicit Load(
1103 Stream stream,
1104 std::shared_ptr<io::Reader> reader,
1105 size_t offset,
1106 bool swap_endianness = false)
1107 : UnaryPrimitive(stream),
1108 reader_(reader),
1109 offset_(offset),
1110 swap_endianness_(swap_endianness) {};
1111
1112 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1113 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1114
1116
1117 private:
1118 void eval(const std::vector<array>& inputs, array& out);
1119 std::shared_ptr<io::Reader> reader_;
1120 size_t offset_;
1121 bool swap_endianness_;
1122};
1123
// Log: single-output primitive storing one value of the nested enum Base
// (two, ten, e).
// NOTE(review): embedded listing numbers jump 1133 -> 1138 (dropped lines).
1124class Log : public UnaryPrimitive {
1125 public:
// Unscoped enum; its enumerators are referenced unqualified in print().
1126 enum Base { two, ten, e };
1127
1128 explicit Log(Stream stream, Base base)
1129 : UnaryPrimitive(stream), base_(base) {};
1130
1131 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1132 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1133
1138
// Writes "Log", "Log2" or "Log10" for e, two and ten respectively. The switch
// has no default, so any other stored value writes nothing.
1139 void print(std::ostream& os) override {
1140 switch (base_) {
1141 case e:
1142 os << "Log";
1143 break;
1144 case two:
1145 os << "Log2";
1146 break;
1147 case ten:
1148 os << "Log10";
1149 break;
1150 }
1151 }
1152
1153 private:
1154 Base base_;
1155 void eval(const std::vector<array>& inputs, array& out);
1156};
1157
// Log1p: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1164 -> 1169 (dropped lines).
1158class Log1p : public UnaryPrimitive {
1159 public:
1160 explicit Log1p(Stream stream) : UnaryPrimitive(stream) {};
1161
1162 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1163 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1164
1169
1170 private:
1171 void eval(const std::vector<array>& inputs, array& out);
1172};
1173
// Body of the class whose constructor is named LogicalNot; it initializes a
// UnaryPrimitive base and holds no other state.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 1173 -> 1175, and again 1180 -> 1186) — restore from upstream.
1175 public:
1176 explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {};
1177
1178 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1179 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1180
1186
1187 private:
1188 void eval(const std::vector<array>& inputs, array& out);
1189};
1190
// Body of the class whose constructor is named LogicalAnd; it initializes a
// UnaryPrimitive base and holds no other state.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 1190 -> 1192, and again 1197 -> 1203) — restore from upstream.
1192 public:
1193 explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {};
1194
1195 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1196 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1197
1203
1204 private:
1205 void eval(const std::vector<array>& inputs, array& out);
1206};
1207
// Body of the class whose constructor is named LogicalOr; it initializes a
// UnaryPrimitive base and holds no other state.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 1207 -> 1209, and again 1214 -> 1220) — restore from upstream.
1209 public:
1210 explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {};
1211
1212 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1213 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1214
1220
1221 private:
1222 void eval(const std::vector<array>& inputs, array& out);
1223};
1224
// Body of the class whose constructor is named LogAddExp; it initializes a
// UnaryPrimitive base and holds no other state.
// NOTE(review): the class header line is absent from this listing (embedded
// numbers jump 1224 -> 1226, and again 1231 -> 1237) — restore from upstream.
1226 public:
1227 explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {};
1228
1229 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1230 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1231
1237
1238 private:
1239 void eval(const std::vector<array>& inputs, array& out);
1240};
1241
// Matmul: single-output primitive with no state beyond the stream. Declares
// eval and vjp overrides; no private eval() helper is visible for this class.
// NOTE(review): embedded listing numbers jump 1254 -> 1258 (dropped lines).
1242class Matmul : public UnaryPrimitive {
1243 public:
1244 explicit Matmul(Stream stream) : UnaryPrimitive(stream) {};
1245
1246 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1247 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1248
1249 std::vector<array> vjp(
1250 const std::vector<array>& primals,
1251 const std::vector<array>& cotangents,
1252 const std::vector<int>& argnums,
1253 const std::vector<array>& outputs) override;
1254
1258};
1259
// Maximum: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1266 -> 1272 (dropped lines).
1260class Maximum : public UnaryPrimitive {
1261 public:
1262 explicit Maximum(Stream stream) : UnaryPrimitive(stream) {};
1263
1264 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1265 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1266
1272
1273 private:
1274 void eval(const std::vector<array>& inputs, array& out);
1275};
1276
// Minimum: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1283 -> 1289 (dropped lines).
1277class Minimum : public UnaryPrimitive {
1278 public:
1279 explicit Minimum(Stream stream) : UnaryPrimitive(stream) {};
1280
1281 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1282 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1283
1289
1290 private:
1291 void eval(const std::vector<array>& inputs, array& out);
1292};
1293
// Multiply: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1300 -> 1306 (dropped lines).
1294class Multiply : public UnaryPrimitive {
1295 public:
1296 explicit Multiply(Stream stream) : UnaryPrimitive(stream) {};
1297
1298 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1299 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1300
1306
1307 private:
1308 void eval(const std::vector<array>& inputs, array& out);
1309};
1310
// Negative: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1317 -> 1323 (dropped lines).
1311class Negative : public UnaryPrimitive {
1312 public:
1313 explicit Negative(Stream stream) : UnaryPrimitive(stream) {};
1314
1315 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1316 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1317
1323
1324 private:
1325 void eval(const std::vector<array>& inputs, array& out);
1326};
1327
// NotEqual: single-output primitive with no state beyond the stream.
// Overrides the CPU/GPU evaluation hooks and declares a private eval() helper.
// NOTE(review): embedded listing numbers jump 1334 -> 1340 (dropped lines).
1328class NotEqual : public UnaryPrimitive {
1329 public:
1330 explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {};
1331
1332 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1333 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1334
1340
1341 private:
1342 void eval(const std::vector<array>& inputs, array& out);
1343};
1344
// Body of a UnaryPrimitive-derived class that stores an axes vector (moved
// in), an `inverted` flag and a Dtype.
// NOTE(review): both the class header and the line naming the constructor are
// absent from this listing (embedded numbers jump 1344 -> 1346 and
// 1346 -> 1348), so the class name cannot be read here — restore from
// upstream.
1346 public:
1348 Stream stream,
1349 std::vector<int> axes,
1350 bool inverted,
1351 Dtype dtype)
1352 : UnaryPrimitive(stream),
1353 axes_(std::move(axes)),
1354 inverted_(inverted),
1355 dtype_(dtype) {}
1356
1357 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1358 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1359
1362 bool is_equivalent(const Primitive& other) const override;
// Ignores `inputs` and returns a brace-initialized `{{}}` — presumably one
// output with an empty shape (a scalar); confirm.
1363 std::vector<std::vector<int>> output_shapes(
1364 const std::vector<array>& inputs) override {
1365 return {{}};
1366 }
1367
1368 private:
1369 std::vector<int> axes_;
1370 bool inverted_;
1371 Dtype dtype_;
1372
1373 void eval(const std::vector<array>& inputs, array& out);
1374};
1375
// Pad: single-output primitive storing copies of three int vectors: axes,
// low_pad_size and high_pad_size. Declares eval and is_equivalent overrides.
// NOTE(review): embedded listing numbers jump 1390 -> 1394 (dropped lines).
1376class Pad : public UnaryPrimitive {
1377 public:
1378 explicit Pad(
1379 Stream stream,
1380 const std::vector<int>& axes,
1381 const std::vector<int>& low_pad_size,
1382 const std::vector<int>& high_pad_size)
1383 : UnaryPrimitive(stream),
1384 axes_(axes),
1385 low_pad_size_(low_pad_size),
1386 high_pad_size_(high_pad_size) {};
1387
1388 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1389 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1390
1394 bool is_equivalent(const Primitive& other) const override;
1395
1396 private:
1397 std::vector<int> axes_;
1398 std::vector<int> low_pad_size_;
1399 std::vector<int> high_pad_size_;
1400
1401 void eval(const std::vector<array>& inputs, array& out);
1402};
1403
// Partition: the class header (listing line 1404) is missing from this
// extract; the name is taken from the constructor below, which forwards the
// stream to UnaryPrimitive and stores `kth` and `axis`.
1405 public:
1406 explicit Partition(Stream stream, int kth, int axis)
1407 : UnaryPrimitive(stream), kth_(kth), axis_(axis) {};
1408
1409 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1410 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1411
// NOTE(review): listing lines 1412-1415 are missing from this extract.
1416 bool is_equivalent(const Primitive& other) const override;
1417
1418 private:
1419 int kth_;
1420 int axis_;
1421
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1422 void eval(const std::vector<array>& inputs, array& out);
1423};
1424
// Power: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1425class Power : public UnaryPrimitive {
1426 public:
1427 explicit Power(Stream stream) : UnaryPrimitive(stream) {};
1428
1429 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1430 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1431
// NOTE(review): listing lines 1432-1436 are missing from this extract; any
// members declared there are not visible here.
1437
1438 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1439 void eval(const std::vector<array>& inputs, array& out);
1440};
1441
// NOTE(review): the class header (listing line 1442) and the first line of the
// constructor (1444) are missing from this extract, so the class name cannot
// be read here; presumably QuantizedMatmul -- confirm against the header.
// The visible constructor parameters are a Stream, group_size, bits and a
// transpose flag, each stored in the matching private member.
1443 public:
1445 Stream stream,
1446 int group_size,
1447 int bits,
1448 bool transpose)
1449 : UnaryPrimitive(stream),
1450 group_size_(group_size),
1451 bits_(bits),
1452 transpose_(transpose) {};
1453
1454 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1455 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1456
// NOTE(review): listing lines 1457-1459 are missing from this extract.
1460 bool is_equivalent(const Primitive& other) const override;
1461
1462 private:
1463 int group_size_;
1464 int bits_;
1465 bool transpose_;
1466
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1467 void eval(const std::vector<array>& inputs, array& out);
1468};
1469
// GatherQMM: the class header (listing line 1470) is missing from this
// extract; the name is taken from the constructor below, which forwards the
// stream to UnaryPrimitive and stores group_size, bits and the transpose flag.
1471 public:
1472 explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
1473 : UnaryPrimitive(stream),
1474 group_size_(group_size),
1475 bits_(bits),
1476 transpose_(transpose) {};
1477
1478 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1479 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1480
// NOTE(review): listing lines 1481-1483 are missing from this extract.
1484 bool is_equivalent(const Primitive& other) const override;
1485
1486 private:
1487 int group_size_;
1488 int bits_;
1489 bool transpose_;
1490
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1491 void eval(const std::vector<array>& inputs, array& out);
1492};
1493
// RandomBits: the class header (listing line 1494) is missing from this
// extract; the name is taken from the constructor below, which forwards the
// stream to UnaryPrimitive and stores a copy of `shape` plus `width`.
// NOTE(review): the unit of `width` is not visible here -- confirm.
1495 public:
1496 explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
1497 : UnaryPrimitive(stream), shape_(shape), width_(width) {};
1498
1499 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1500 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1501
// NOTE(review): listing lines 1502-1503 are missing from this extract.
1504 bool is_equivalent(const Primitive& other) const override;
1505
1506 private:
1507 std::vector<int> shape_;
1508 int width_;
1509
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1510 void eval(const std::vector<array>& inputs, array& out);
1511};
1512
// Reshape: primitive deriving from UnaryPrimitive. The constructor stores a
// copy of the requested `shape` in shape_.
1513class Reshape : public UnaryPrimitive {
1514 public:
1515 explicit Reshape(Stream stream, const std::vector<int>& shape)
1516 : UnaryPrimitive(stream), shape_(shape) {};
1517
1518 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1519 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1520
// NOTE(review): listing lines 1521-1523 are missing from this extract.
1524 bool is_equivalent(const Primitive& other) const override;
1525
1526 private:
1527 std::vector<int> shape_;
1528
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1529 void eval(const std::vector<array>& inputs, array& out);
1530
// Private helper taking the input and output arrays and returning a flag
// together with a vector of size_t.
// NOTE(review): presumably the flag says whether a copy is required and the
// vector holds output strides -- neither is visible here; confirm.
1531 std::pair<bool, std::vector<size_t>> prepare_reshape(
1532 const array& in,
1533 const array& out);
// Private helper taking the input array, a vector named out_strides and the
// output array by mutable reference.
// NOTE(review): the name suggests `out` ends up sharing `in`'s buffer --
// confirm against the definition.
1534 void shared_buffer_reshape(
1535 const array& in,
1536 const std::vector<size_t>& out_strides,
1537 array& out);
1538};
1539
// Reduce: primitive deriving from UnaryPrimitive, parameterized by a
// ReduceType and a list of integer axes, both stored by the constructor.
1540class Reduce : public UnaryPrimitive {
1541 public:
// The kinds of reduction this primitive can represent.
1542 enum ReduceType { And, Or, Sum, Prod, Min, Max };
1543
1544 explicit Reduce(
1545 Stream stream,
1546 ReduceType reduce_type,
1547 const std::vector<int>& axes)
1548 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};
1549
1550 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1551 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1552
// NOTE(review): listing line 1553 is missing from this extract.
1554
// Overrides the base-class vector-Jacobian product; defined out of line.
1555 std::vector<array> vjp(
1556 const std::vector<array>& primals,
1557 const std::vector<array>& cotangents,
1558 const std::vector<int>& argnums,
1559 const std::vector<array>& outputs) override;
1560
1561 std::vector<std::vector<int>> output_shapes(
1562 const std::vector<array>& inputs) override;
1563
// Writes the name of the stored reduce type ("And", "Or", "Sum", "Prod",
// "Min" or "Max") to `os`. There is no default case, so a value outside the
// enumerators writes nothing.
1564 void print(std::ostream& os) override {
1565 switch (reduce_type_) {
1566 case And:
1567 os << "And";
1568 break;
1569 case Or:
1570 os << "Or";
1571 break;
1572 case Sum:
1573 os << "Sum";
1574 break;
1575 case Prod:
1576 os << "Prod";
1577 break;
1578 case Min:
1579 os << "Min";
1580 break;
1581 case Max:
1582 os << "Max";
1583 break;
1584 }
1585 }
1586 bool is_equivalent(const Primitive& other) const override;
1587
1588 private:
1589 ReduceType reduce_type_;
1590 std::vector<int> axes_;
1591
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1592 void eval(const std::vector<array>& inputs, array& out);
1593};
1594
// Round: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1595class Round : public UnaryPrimitive {
1596 public:
1597 explicit Round(Stream stream) : UnaryPrimitive(stream) {};
1598
1599 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1600 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1601
// NOTE(review): listing lines 1602-1606 are missing from this extract; any
// members declared there are not visible here.
1607
1608 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1609 void eval(const std::vector<array>& inputs, array& out);
1610};
1611
// Scan: primitive deriving from UnaryPrimitive. The constructor stores a
// ReduceType, an axis, and the `reverse` and `inclusive` flags.
1612class Scan : public UnaryPrimitive {
1613 public:
// NOTE(review): listing line 1614 is missing from this extract; presumably it
// declares the ReduceType enum (Sum, Prod, Min, Max are used below) -- confirm.
1615
1616 explicit Scan(
1617 Stream stream,
1618 ReduceType reduce_type,
1619 int axis,
1620 bool reverse,
1621 bool inclusive)
1622 : UnaryPrimitive(stream),
1623 reduce_type_(reduce_type),
1624 axis_(axis),
1625 reverse_(reverse),
1626 inclusive_(inclusive) {};
1627
1628 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1629 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1630
// NOTE(review): listing lines 1631-1632 are missing from this extract.
1633
// Writes "Cum" followed by the name of the stored reduce type, e.g. "CumSum".
// There is no default case, so an unlisted value yields just "Cum".
1634 void print(std::ostream& os) override {
1635 os << "Cum";
1636 switch (reduce_type_) {
1637 case Sum:
1638 os << "Sum";
1639 break;
1640 case Prod:
1641 os << "Prod";
1642 break;
1643 case Min:
1644 os << "Min";
1645 break;
1646 case Max:
1647 os << "Max";
1648 break;
1649 }
1650 }
1651 bool is_equivalent(const Primitive& other) const override;
1652
1653 private:
1654 ReduceType reduce_type_;
1655 int axis_;
1656 bool reverse_;
1657 bool inclusive_;
1658
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1659 void eval(const std::vector<array>& inputs, array& out);
1660};
1661
// Scatter: primitive deriving from UnaryPrimitive. The constructor stores a
// ReduceType and a copy of the integer `axes` vector.
1662class Scatter : public UnaryPrimitive {
1663 public:
// NOTE(review): listing line 1664 is missing from this extract; presumably it
// declares the ReduceType enum (Sum, Prod, Min, Max, None are used below) --
// confirm.
1665
1666 explicit Scatter(
1667 Stream stream,
1668 ReduceType reduce_type,
1669 const std::vector<int>& axes)
1670 : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {};
1671
1672 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1673 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1674
// NOTE(review): listing line 1675 is missing from this extract.
// Writes "Scatter", followed by a space and the reduce type name for Sum,
// Prod, Min and Max; nothing is appended for None.
1676 void print(std::ostream& os) override {
1677 os << "Scatter";
1678 switch (reduce_type_) {
1679 case Sum:
1680 os << " Sum";
1681 break;
1682 case Prod:
1683 os << " Prod";
1684 break;
1685 case Min:
1686 os << " Min";
1687 break;
1688 case Max:
1689 os << " Max";
1690 break;
1691 case None:
1692 break;
1693 }
1694 }
1695 bool is_equivalent(const Primitive& other) const override;
1696
1697 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1698 void eval(const std::vector<array>& inputs, array& out);
1699 ReduceType reduce_type_;
1700 std::vector<int> axes_;
1701};
1702
// Sigmoid: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1703class Sigmoid : public UnaryPrimitive {
1704 public:
1705 explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {};
1706
1707 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1708 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1709
// NOTE(review): listing lines 1710-1714 are missing from this extract; any
// members declared there are not visible here.
1715
1716 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1717 void eval(const std::vector<array>& inputs, array& out);
1718};
1719
// Sign: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1720class Sign : public UnaryPrimitive {
1721 public:
1722 explicit Sign(Stream stream) : UnaryPrimitive(stream) {};
1723
1724 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1725 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1726
// NOTE(review): listing lines 1727-1731 are missing from this extract; any
// members declared there are not visible here.
1732
1733 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1734 void eval(const std::vector<array>& inputs, array& out);
1735};
1736
// Sin: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1737class Sin : public UnaryPrimitive {
1738 public:
1739 explicit Sin(Stream stream) : UnaryPrimitive(stream) {};
1740
1741 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1742 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1743
// NOTE(review): listing lines 1744-1748 are missing from this extract; any
// members declared there are not visible here.
1749
1750 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1751 void eval(const std::vector<array>& inputs, array& out);
1752};
1753
// Sinh: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1754class Sinh : public UnaryPrimitive {
1755 public:
1756 explicit Sinh(Stream stream) : UnaryPrimitive(stream) {};
1757
1758 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1759 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1760
// NOTE(review): listing lines 1761-1765 are missing from this extract; any
// members declared there are not visible here.
1766
1767 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1768 void eval(const std::vector<array>& inputs, array& out);
1769};
1770
// Slice: primitive deriving from UnaryPrimitive. The constructor copies three
// integer vectors -- start_indices, end_indices and strides -- into the
// corresponding private members.
// NOTE(review): presumably each vector has one entry per input dimension;
// that is not visible in this declaration -- confirm.
1771class Slice : public UnaryPrimitive {
1772 public:
1773 explicit Slice(
1774 Stream stream,
1775 const std::vector<int>& start_indices,
1776 const std::vector<int>& end_indices,
1777 const std::vector<int>& strides)
1778 : UnaryPrimitive(stream),
1779 start_indices_(start_indices),
1780 end_indices_(end_indices),
1781 strides_(strides) {};
1782
1783 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1784 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1785
// NOTE(review): listing lines 1786-1788 are missing from this extract.
1789 bool is_equivalent(const Primitive& other) const override;
1790
1791 private:
1792 std::vector<int> start_indices_;
1793 std::vector<int> end_indices_;
1794 std::vector<int> strides_;
1795
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1796 void eval(const std::vector<array>& inputs, array& out);
1797};
1798
// SliceUpdate: the class header (listing line 1799) is missing from this
// extract; the name is taken from the constructor below, which forwards the
// stream to UnaryPrimitive and copies start_indices, end_indices and strides
// into the corresponding private members.
1800 public:
1801 explicit SliceUpdate(
1802 Stream stream,
1803 const std::vector<int>& start_indices,
1804 const std::vector<int>& end_indices,
1805 const std::vector<int>& strides)
1806 : UnaryPrimitive(stream),
1807 start_indices_(start_indices),
1808 end_indices_(end_indices),
1809 strides_(strides) {};
1810
1811 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1812 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1813
// NOTE(review): listing lines 1814-1816 are missing from this extract.
1817 bool is_equivalent(const Primitive& other) const override;
1818
1819 private:
1820 std::vector<int> start_indices_;
1821 std::vector<int> end_indices_;
1822 std::vector<int> strides_;
1823
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1824 void eval(const std::vector<array>& inputs, array& out);
1825
// Private helper taking an array and returning an int64_t together with a
// vector of int64_t.
// NOTE(review): presumably a data offset and the strides of the sliced
// region -- not visible here; confirm against the definition.
1826 std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
1827};
1828
// Softmax: primitive deriving from UnaryPrimitive. The constructor stores the
// `precise` flag in precise_.
// NOTE(review): what `precise` changes in the computation is not visible in
// this declaration -- confirm against the implementation.
1829class Softmax : public UnaryPrimitive {
1830 public:
1831 explicit Softmax(Stream stream, bool precise)
1832 : UnaryPrimitive(stream), precise_(precise) {};
1833
1834 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1835 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1836
// NOTE(review): listing lines 1837-1840 are missing from this extract.
1841
1842 bool is_equivalent(const Primitive& other) const override;
1843
1844 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1845 void eval(const std::vector<array>& inputs, array& out);
1846 bool precise_;
1847};
1848
// Sort: primitive deriving from UnaryPrimitive. The constructor stores the
// integer `axis` in axis_.
1849class Sort : public UnaryPrimitive {
1850 public:
1851 explicit Sort(Stream stream, int axis)
1852 : UnaryPrimitive(stream), axis_(axis) {};
1853
1854 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1855 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1856
// NOTE(review): listing lines 1857-1860 are missing from this extract.
1861 bool is_equivalent(const Primitive& other) const override;
1862
1863 private:
1864 int axis_;
1865
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1866 void eval(const std::vector<array>& inputs, array& out);
1867};
1868
// Split: derives directly from Primitive (not UnaryPrimitive), so its
// evaluation overrides take a vector of output arrays rather than a single
// one. The constructor stores a copy of `indices` and the integer `axis`.
1869class Split : public Primitive {
1870 public:
1871 explicit Split(Stream stream, const std::vector<int>& indices, int axis)
1872 : Primitive(stream), indices_(indices), axis_(axis) {};
1873
1874 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1875 override;
1876 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
1877 override;
1878
// NOTE(review): listing lines 1879-1881 are missing from this extract.
1882 bool is_equivalent(const Primitive& other) const override;
1883
1884 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1885 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
1886
1887 std::vector<int> indices_;
1888 int axis_;
1889};
1890
// Square: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1891class Square : public UnaryPrimitive {
1892 public:
1893 explicit Square(Stream stream) : UnaryPrimitive(stream) {};
1894
1895 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1896 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1897
// NOTE(review): listing lines 1898-1902 are missing from this extract; any
// members declared there are not visible here.
1903
1904 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1905 void eval(const std::vector<array>& inputs, array& out);
1906};
1907
// Sqrt: primitive deriving from UnaryPrimitive. The constructor takes an
// optional `recip` flag (default false) and stores it in recip_; print()
// reports the primitive as "Rsqrt" when the flag is set and "Sqrt" otherwise.
1908class Sqrt : public UnaryPrimitive {
1909 public:
1910 explicit Sqrt(Stream stream, bool recip = false)
1911 : UnaryPrimitive(stream), recip_(recip) {};
1912
1913 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1914 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1915
// NOTE(review): listing lines 1916-1918 are missing from this extract.
1919 bool is_equivalent(const Primitive& other) const override;
1920
// Writes "Rsqrt" or "Sqrt" to `os` depending on recip_.
1921 void print(std::ostream& os) override {
1922 if (recip_) {
1923 os << "Rsqrt";
1924 } else {
1925 os << "Sqrt";
1926 }
1927 }
1928
1929 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1930 void eval(const std::vector<array>& inputs, array& out);
1931 bool recip_;
1932};
1933
// StopGradient: the class header (listing line 1934) is missing from this
// extract; the name is taken from the constructor below, which only forwards
// the stream to UnaryPrimitive.
1935 public:
1936 explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {};
1937
1938 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1939 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1940
// NOTE(review): listing lines 1941-1944 are missing from this extract; any
// members declared there are not visible here.
1945
1946 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1947 void eval(const std::vector<array>& inputs, array& out);
1948};
1949
// Subtract: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1950class Subtract : public UnaryPrimitive {
1951 public:
1952 explicit Subtract(Stream stream) : UnaryPrimitive(stream) {};
1953
1954 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1955 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1956
// NOTE(review): listing lines 1957-1961 are missing from this extract; any
// members declared there are not visible here.
1962
1963 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1964 void eval(const std::vector<array>& inputs, array& out);
1965};
1966
// Tan: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1967class Tan : public UnaryPrimitive {
1968 public:
1969 explicit Tan(Stream stream) : UnaryPrimitive(stream) {};
1970
1971 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1972 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1973
// NOTE(review): listing lines 1974-1978 are missing from this extract; any
// members declared there are not visible here.
1979
1980 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1981 void eval(const std::vector<array>& inputs, array& out);
1982};
1983
// Tanh: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
1984class Tanh : public UnaryPrimitive {
1985 public:
1986 explicit Tanh(Stream stream) : UnaryPrimitive(stream) {};
1987
1988 void eval_cpu(const std::vector<array>& inputs, array& out) override;
1989 void eval_gpu(const std::vector<array>& inputs, array& out) override;
1990
// NOTE(review): listing lines 1991-1995 are missing from this extract; any
// members declared there are not visible here.
1996
1997 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
1998 void eval(const std::vector<array>& inputs, array& out);
1999};
2000
// Uniform: primitive deriving from UnaryPrimitive, constructed from a Stream.
// Declares the CPU and GPU evaluation overrides, each taking the vector of
// input arrays and one output array by reference.
2001class Uniform : public UnaryPrimitive {
2002 public:
2003 explicit Uniform(Stream stream) : UnaryPrimitive(stream) {};
2004
2005 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2006 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2007
// NOTE(review): listing lines 2008-2010 are missing from this extract; any
// members declared there are not visible here.
2011
2012 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
2013 void eval(const std::vector<array>& inputs, array& out);
2014};
2015
// View: primitive deriving from UnaryPrimitive. The constructor stores the
// given Dtype in dtype_. Unlike most neighbouring classes, print() is only
// declared here (defined out of line) and no private eval helper is visible.
2016class View : public UnaryPrimitive {
2017 public:
2018 explicit View(Stream stream, Dtype dtype)
2019 : UnaryPrimitive(stream), dtype_(dtype) {};
2020
2021 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2022 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2023
// NOTE(review): listing line 2024 is missing from this extract.
2025 void print(std::ostream& os) override;
2026 bool is_equivalent(const Primitive& other) const override;
2027
2028 private:
2029 Dtype dtype_;
2030};
2031
// Transpose: the class header (listing line 2032) is missing from this
// extract; the name is taken from the constructor below, which forwards the
// stream to UnaryPrimitive and stores a copy of `axes` in axes_.
// NOTE(review): presumably `axes` is the permutation of input dimensions --
// not visible in this declaration; confirm.
2033 public:
2034 explicit Transpose(Stream stream, const std::vector<int>& axes)
2035 : UnaryPrimitive(stream), axes_(axes) {};
2036
2037 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2038 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2039
// NOTE(review): listing lines 2040-2042 are missing from this extract.
2043 bool is_equivalent(const Primitive& other) const override;
2044
2045 private:
2046 std::vector<int> axes_;
2047
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
2048 void eval(const std::vector<array>& inputs, array& out);
2049};
2050
// QRF derives directly from Primitive, so its evaluation overrides take a
// vector of output arrays rather than a single output.
2051/* QR Factorization primitive. */
2052class QRF : public Primitive {
2053 public:
2054 explicit QRF(Stream stream) : Primitive(stream) {};
2055
2056 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2057 override;
2058 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2059 override;
2060
// NOTE(review): listing line 2061 is missing from this extract.
2062
2063 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
2064 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2065};
2066
// SVD derives directly from Primitive, so its evaluation overrides take a
// vector of output arrays rather than a single output.
2067/* SVD primitive. */
2068class SVD : public Primitive {
2069 public:
2070 explicit SVD(Stream stream) : Primitive(stream) {};
2071
2072 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2073 override;
2074 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2075 override;
2076
// NOTE(review): listing lines 2077-2078 are missing from this extract.
2079
2080 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
2081 void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2082};
2083
// Inverse derives from UnaryPrimitive; its evaluation overrides take the
// vector of input arrays and one output array by reference.
2084/* Matrix inversion primitive. */
2085class Inverse : public UnaryPrimitive {
2086 public:
2087 explicit Inverse(Stream stream) : UnaryPrimitive(stream) {};
2088
2089 void eval_cpu(const std::vector<array>& inputs, array& output) override;
2090 void eval_gpu(const std::vector<array>& inputs, array& output) override;
2091
// NOTE(review): listing lines 2092-2093 are missing from this extract.
2094
2095 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
2096 void eval(const std::vector<array>& inputs, array& output);
2097};
2098
// Cholesky: primitive deriving from UnaryPrimitive. The constructor stores
// the `upper` flag in upper_.
// NOTE(review): presumably `upper` selects the upper- rather than the
// lower-triangular factor -- not visible in this declaration; confirm.
2099class Cholesky : public UnaryPrimitive {
2100 public:
2101 explicit Cholesky(Stream stream, bool upper)
2102 : UnaryPrimitive(stream), upper_(upper) {};
2103
2104 void eval_cpu(const std::vector<array>& inputs, array& out) override;
2105 void eval_gpu(const std::vector<array>& inputs, array& out) override;
2106
// NOTE(review): listing lines 2107-2108 are missing from this extract.
2109
2110 private:
// Helper with the same signature as eval_cpu/eval_gpu; presumably their
// shared implementation -- confirm against the definitions.
2111 void eval(const std::vector<array>& inputs, array& output);
2112 bool upper_;
2113};
2114
2115} // namespace mlx::core
Definition primitives.h:155
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Abs(Stream stream)
Definition primitives.h:157
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:164
std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs) override
Get the output shapes of the primitive.
Definition primitives.h:166
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:165
Definition primitives.h:172
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Add(Stream stream)
Definition primitives.h:174
Definition primitives.h:189
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
AddMM(Stream stream, float alpha, float beta)
Definition primitives.h:191
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:213
Arange(Stream stream, double start, double stop, double step)
Definition primitives.h:215
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:232
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCos(Stream stream)
Definition primitives.h:234
Definition primitives.h:249
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcCosh(Stream stream)
Definition primitives.h:251
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:266
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcSin(Stream stream)
Definition primitives.h:268
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:283
ArcSinh(Stream stream)
Definition primitives.h:285
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:317
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTan2(Stream stream)
Definition primitives.h:319
Definition primitives.h:300
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArcTan(Stream stream)
Definition primitives.h:302
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:334
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArcTanh(Stream stream)
Definition primitives.h:336
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:351
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
ArgPartition(Stream stream, int kth, int axis)
Definition primitives.h:353
Definition primitives.h:371
ReduceType
Definition primitives.h:373
@ ArgMin
Definition primitives.h:374
@ ArgMax
Definition primitives.h:375
ArgReduce(Stream stream, ReduceType reduce_type, int axis)
Definition primitives.h:378
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:397
void eval_cpu(const std::vector< array > &inputs, array &out) override
ArgSort(Stream stream, int axis)
Definition primitives.h:399
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:436
AsStrided(Stream stream, std::vector< int > shape, std::vector< size_t > strides, size_t offset)
Definition primitives.h:438
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:416
void eval_gpu(const std::vector< array > &inputs, array &out) override
AsType(Stream stream, Dtype dtype)
Definition primitives.h:418
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:463
BitwiseBinary(Stream stream, Op op)
Definition primitives.h:467
void eval_cpu(const std::vector< array > &inputs, array &out) override
Op
Definition primitives.h:465
@ And
Definition primitives.h:465
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:482
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
BlockMaskedMM(Stream stream, int block_size)
Definition primitives.h:484
Definition primitives.h:525
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Broadcast(Stream stream, const std::vector< int > &shape)
Definition primitives.h:527
Definition primitives.h:544
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Ceil(Stream stream)
Definition primitives.h:546
Definition primitives.h:2099
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cholesky(Stream stream, bool upper)
Definition primitives.h:2101
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:561
Compiled(Stream stream, std::vector< array > inputs, std::vector< array > outputs, std::vector< array > tape, std::unordered_set< uintptr_t > constant_ids)
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:604
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Concatenate(Stream stream, int axis)
Definition primitives.h:606
Definition primitives.h:623
Conjugate(Stream stream)
Definition primitives.h:625
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:639
void eval_gpu(const std::vector< array > &inputs, array &out) override
Convolution(Stream stream, const std::vector< int > &kernel_strides, const std::vector< int > &padding, const std::vector< int > &kernel_dilation, const std::vector< int > &input_dilation, const int groups=1, const bool flip=false)
Definition primitives.h:641
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
Definition primitives.h:680
void eval_gpu(const std::vector< array > &inputs, array &out) override
Copy(Stream stream)
Definition primitives.h:682
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:697
void eval_cpu(const std::vector< array > &inputs, array &out) override
Cos(Stream stream)
Definition primitives.h:699
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:714
void eval_gpu(const std::vector< array > &inputs, array &out) override
Cosh(Stream stream)
Definition primitives.h:716
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:731
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
CustomVJP(Stream stream, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun)
Definition primitives.h:733
Definition primitives.h:764
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotan, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Depends(Stream stream)
Definition primitives.h:766
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:802
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
DivMod(Stream stream)
Definition primitives.h:804
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:785
Divide(Stream stream)
Definition primitives.h:787
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:858
Equal(Stream stream, bool equal_nan=false)
Definition primitives.h:860
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:884
Erf(Stream stream)
Definition primitives.h:886
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:901
void eval_gpu(const std::vector< array > &inputs, array &out) override
ErfInv(Stream stream)
Definition primitives.h:903
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:918
Exp(Stream stream)
Definition primitives.h:920
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:935
Expm1(Stream stream)
Definition primitives.h:937
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:951
FFT(Stream stream, const std::vector< size_t > &axes, bool inverse, bool real)
Definition primitives.h:953
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:977
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Floor(Stream stream)
Definition primitives.h:979
Definition primitives.h:994
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Full(Stream stream)
Definition primitives.h:996
Definition primitives.h:1010
Gather(Stream stream, const std::vector< int > &axes, const std::vector< int > &slice_sizes)
Definition primitives.h:1012
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:505
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
GatherMM(Stream stream)
Definition primitives.h:507
Definition primitives.h:1470
GatherQMM(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1472
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1049
void eval_cpu(const std::vector< array > &inputs, array &out) override
GreaterEqual(Stream stream)
Definition primitives.h:1051
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1032
Greater(Stream stream)
Definition primitives.h:1034
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2085
void eval_gpu(const std::vector< array > &inputs, array &output) override
Inverse(Stream stream)
Definition primitives.h:2087
void eval_cpu(const std::vector< array > &inputs, array &output) override
Definition primitives.h:1083
LessEqual(Stream stream)
Definition primitives.h:1085
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1066
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Less(Stream stream)
Definition primitives.h:1068
Definition primitives.h:1100
void eval_gpu(const std::vector< array > &inputs, array &out) override
Load(Stream stream, std::shared_ptr< io::Reader > reader, size_t offset, bool swap_endianness=false)
Definition primitives.h:1102
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1158
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Log1p(Stream stream)
Definition primitives.h:1160
Definition primitives.h:1225
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogAddExp(Stream stream)
Definition primitives.h:1227
Definition primitives.h:1124
Base
Definition primitives.h:1126
Log(Stream stream, Base base)
Definition primitives.h:1128
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1191
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalAnd(Stream stream)
Definition primitives.h:1193
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1174
void eval_gpu(const std::vector< array > &inputs, array &out) override
LogicalNot(Stream stream)
Definition primitives.h:1176
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1208
void eval_cpu(const std::vector< array > &inputs, array &out) override
LogicalOr(Stream stream)
Definition primitives.h:1210
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1242
void eval_cpu(const std::vector< array > &inputs, array &out) override
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_gpu(const std::vector< array > &inputs, array &out) override
Matmul(Stream stream)
Definition primitives.h:1244
Definition primitives.h:1260
Maximum(Stream stream)
Definition primitives.h:1262
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1277
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Minimum(Stream stream)
Definition primitives.h:1279
Definition primitives.h:1294
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Multiply(Stream stream)
Definition primitives.h:1296
Definition primitives.h:1311
void eval_gpu(const std::vector< array > &inputs, array &out) override
Negative(Stream stream)
Definition primitives.h:1313
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1328
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
NotEqual(Stream stream)
Definition primitives.h:1330
Definition primitives.h:1345
void eval_gpu(const std::vector< array > &inputs, array &out) override
NumberOfElements(Stream stream, std::vector< int > axes, bool inverted, Dtype dtype)
Definition primitives.h:1347
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1376
void eval_cpu(const std::vector< array > &inputs, array &out) override
Pad(Stream stream, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size)
Definition primitives.h:1378
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1404
void eval_cpu(const std::vector< array > &inputs, array &out) override
Partition(Stream stream, int kth, int axis)
Definition primitives.h:1406
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1425
void eval_cpu(const std::vector< array > &inputs, array &out) override
Power(Stream stream)
Definition primitives.h:1427
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:48
virtual void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
virtual std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs)
The vector-Jacobian product.
virtual ~Primitive()=default
Primitive(const Primitive &other)=delete
Primitive(Primitive &&other)=delete
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive & operator=(Primitive &&other)=delete
virtual bool is_equivalent(const Primitive &other) const
Equivalence check defaults to false unless overridden by the primitive.
Definition primitives.h:107
Primitive & operator=(const Primitive &other)=delete
virtual std::vector< std::vector< int > > output_shapes(const std::vector< array > &inputs)
Get the output shapes of the primitive.
const Device & device()
The device the primitive will run on.
Definition primitives.h:53
virtual std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums)
The Jacobian-vector product.
virtual std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes)
The primitive must know how to vectorize itself across the given axes.
virtual void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs)=0
virtual void print(std::ostream &os)=0
Print the primitive.
Primitive(Stream stream)
Definition primitives.h:50
Definition primitives.h:2052
QRF(Stream stream)
Definition primitives.h:2054
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:1442
void eval_gpu(const std::vector< array > &inputs, array &out) override
QuantizedMatmul(Stream stream, int group_size, int bits, bool transpose)
Definition primitives.h:1444
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1494
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
RandomBits(Stream stream, const std::vector< int > &shape, int width)
Definition primitives.h:1496
Definition primitives.h:1540
Reduce(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1544
ReduceType
Definition primitives.h:1542
@ And
Definition primitives.h:1542
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:841
Remainder(Stream stream)
Definition primitives.h:843
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1513
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Reshape(Stream stream, const std::vector< int > &shape)
Definition primitives.h:1515
Definition primitives.h:1595
Round(Stream stream)
Definition primitives.h:1597
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2068
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
SVD(Stream stream)
Definition primitives.h:2070
Definition primitives.h:1612
void eval_cpu(const std::vector< array > &inputs, array &out) override
ReduceType
Definition primitives.h:1614
@ Max
Definition primitives.h:1614
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
Scan(Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive)
Definition primitives.h:1616
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1662
bool is_equivalent(const Primitive &other) const override
Equivalence check defaults to false unless overridden by the primitive.
ReduceType
Definition primitives.h:1664
@ Max
Definition primitives.h:1664
void eval_cpu(const std::vector< array > &inputs, array &out) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:1676
void eval_gpu(const std::vector< array > &inputs, array &out) override
Scatter(Stream stream, ReduceType reduce_type, const std::vector< int > &axes)
Definition primitives.h:1666
Definition primitives.h:824
void eval_gpu(const std::vector< array > &inputs, array &out) override
Select(Stream stream)
Definition primitives.h:826
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1703
Sigmoid(Stream stream)
Definition primitives.h:1705
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1720
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Sign(Stream stream)
Definition primitives.h:1722
Definition primitives.h:1737
Sin(Stream stream)
Definition primitives.h:1739
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1754
Sinh(Stream stream)
Definition primitives.h:1756
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1771
void eval_cpu(const std::vector< array > &inputs, array &out) override
Slice(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1773
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1799
SliceUpdate(Stream stream, const std::vector< int > &start_indices, const std::vector< int > &end_indices, const std::vector< int > &strides)
Definition primitives.h:1801
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1829
void eval_gpu(const std::vector< array > &inputs, array &out) override
Softmax(Stream stream, bool precise)
Definition primitives.h:1831
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1849
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sort(Stream stream, int axis)
Definition primitives.h:1851
Definition primitives.h:1869
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Split(Stream stream, const std::vector< int > &indices, int axis)
Definition primitives.h:1871
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:1908
void eval_cpu(const std::vector< array > &inputs, array &out) override
Sqrt(Stream stream, bool recip=false)
Definition primitives.h:1910
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1891
void eval_gpu(const std::vector< array > &inputs, array &out) override
void eval_cpu(const std::vector< array > &inputs, array &out) override
Square(Stream stream)
Definition primitives.h:1893
Definition primitives.h:1934
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
StopGradient(Stream stream)
Definition primitives.h:1936
Definition primitives.h:1950
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Subtract(Stream stream)
Definition primitives.h:1952
Definition primitives.h:1967
Tan(Stream stream)
Definition primitives.h:1969
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:1984
void eval_gpu(const std::vector< array > &inputs, array &out) override
Tanh(Stream stream)
Definition primitives.h:1986
void eval_cpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:2032
Transpose(Stream stream, const std::vector< int > &axes)
Definition primitives.h:2034
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition primitives.h:127
UnaryPrimitive & operator=(const UnaryPrimitive &other)=delete
UnaryPrimitive(Stream stream)
An abstract base class for a primitive with a single output.
Definition primitives.h:132
virtual void eval_gpu(const std::vector< array > &inputs, array &output)=0
UnaryPrimitive(UnaryPrimitive &&other)=delete
virtual void eval_cpu(const std::vector< array > &inputs, array &output)=0
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:142
UnaryPrimitive(const UnaryPrimitive &other)=delete
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition primitives.h:137
UnaryPrimitive & operator=(UnaryPrimitive &&other)=delete
virtual ~UnaryPrimitive()=default
Definition primitives.h:2001
void eval_cpu(const std::vector< array > &inputs, array &out) override
void eval_gpu(const std::vector< array > &inputs, array &out) override
Uniform(Stream stream)
Definition primitives.h:2003
Definition primitives.h:2016
void eval_cpu(const std::vector< array > &inputs, array &out) override
View(Stream stream, Dtype dtype)
Definition primitives.h:2018
void eval_gpu(const std::vector< array > &inputs, array &out) override
Definition array.h:20
Op op
Definition binary.h:141
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
Definition allocator.h:7
void eval(std::vector< array > outputs)
#define DEFINE_DEFAULT_IS_EQUIVALENT()
Definition primitives.h:34
#define DEFINE_PRINT(PRIMITIVE)
Definition primitives.h:29
#define DEFINE_INPUT_OUTPUT_SHAPE()
Definition primitives.h:39
#define DEFINE_GRADS()
Definition primitives.h:17
#define DEFINE_VMAP()
Definition primitives.h:12
Definition ops.h:23
Definition binary_ops.h:270
Definition ops.h:159
Definition ops.h:139
Definition ops.h:15
Definition ops.h:61
Definition ops.h:119
Definition binary_ops.h:277
Definition ops.h:99
Definition device.h:7
Definition dtype.h:15
Definition stream.h:9
Device device
Definition stream.h:11