MLX
 
Loading...
Searching...
No Matches
arange.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
#include <cassert>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include "mlx/allocator.h"
#include "mlx/array.h"
7
8namespace mlx::core {
9
10namespace {
11
12template <typename T>
13void arange(T start, T next, array& out, size_t size) {
14 auto ptr = out.data<T>();
15 auto step_size = next - start;
16 for (int i = 0; i < size; ++i) {
17 ptr[i] = start;
18 start += step_size;
19 }
20}
21
22} // namespace
23
24void arange(
25 const std::vector<array>& inputs,
26 array& out,
27 double start,
28 double step) {
29 assert(inputs.size() == 0);
31 switch (out.dtype()) {
32 case bool_:
33 throw std::runtime_error("Bool type unsupported for arange.");
34 break;
35 case uint8:
36 arange<uint8_t>(start, start + step, out, out.size());
37 break;
38 case uint16:
39 arange<uint16_t>(start, start + step, out, out.size());
40 break;
41 case uint32:
42 arange<uint32_t>(start, start + step, out, out.size());
43 break;
44 case uint64:
45 arange<uint64_t>(start, start + step, out, out.size());
46 break;
47 case int8:
48 arange<int8_t>(start, start + step, out, out.size());
49 break;
50 case int16:
51 arange<int16_t>(start, start + step, out, out.size());
52 break;
53 case int32:
54 arange<int32_t>(start, start + step, out, out.size());
55 break;
56 case int64:
57 arange<int64_t>(start, start + step, out, out.size());
58 break;
59 case float16:
60 arange<float16_t>(start, start + step, out, out.size());
61 break;
62 case float32:
63 arange<float>(start, start + step, out, out.size());
64 break;
65 case bfloat16:
66 arange<bfloat16_t>(start, start + step, out, out.size());
67 break;
68 case complex64:
69 arange<complex64_t>(start, start + step, out, out.size());
70 break;
71 }
72}
73
74} // namespace mlx::core
Definition array.h:24
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
size_t size() const
The number of elements in the array.
Definition array.h:88
Dtype dtype() const
Get the array's data type.
Definition array.h:131
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
constexpr Dtype bool_
Definition dtype.h:67
constexpr Dtype uint64
Definition dtype.h:72
constexpr Dtype uint16
Definition dtype.h:70
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition arange.h:24
constexpr Dtype bfloat16
Definition dtype.h:81
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
constexpr Dtype int16
Definition dtype.h:75
constexpr Dtype int8
Definition dtype.h:74
constexpr Dtype int64
Definition dtype.h:77
constexpr Dtype uint8
Definition dtype.h:69
constexpr Dtype float16
Definition dtype.h:79
constexpr Dtype uint32
Definition dtype.h:71
constexpr Dtype complex64
Definition dtype.h:82