MLX
Loading...
Searching...
No Matches
mlx
backend
common
arange.h
Go to the documentation of this file.
1
// Copyright © 2023 Apple Inc.
2
3
#pragma once
4
5
#include "
mlx/allocator.h
"
6
#include "
mlx/array.h
"
7
8
namespace
mlx::core
{
9
10
namespace
{
11
12
template
<
typename
T>
13
void
arange
(T start, T
next
, array& out,
size_t
size) {
14
auto
ptr = out.data<T>();
15
auto
step_size =
next
- start;
16
for
(
int
i = 0; i < size; ++i) {
17
ptr[i] = start;
18
start += step_size;
19
}
20
}
21
22
}
// namespace
23
24
void
arange
(
25
const
std::vector<array>& inputs,
26
array
& out,
27
double
start,
28
double
step) {
29
assert(inputs.size() == 0);
30
out.
set_data
(
allocator::malloc_or_wait
(out.
nbytes
()));
31
switch
(out.
dtype
()) {
32
case
bool_
:
33
throw
std::runtime_error(
"Bool type unsupported for arange."
);
34
break
;
35
case
uint8
:
36
arange<uint8_t>
(start, start + step, out, out.
size
());
37
break
;
38
case
uint16
:
39
arange<uint16_t>
(start, start + step, out, out.
size
());
40
break
;
41
case
uint32
:
42
arange<uint32_t>
(start, start + step, out, out.
size
());
43
break
;
44
case
uint64
:
45
arange<uint64_t>
(start, start + step, out, out.
size
());
46
break
;
47
case
int8
:
48
arange<int8_t>
(start, start + step, out, out.
size
());
49
break
;
50
case
int16
:
51
arange<int16_t>
(start, start + step, out, out.
size
());
52
break
;
53
case
int32
:
54
arange<int32_t>
(start, start + step, out, out.
size
());
55
break
;
56
case
int64
:
57
arange<int64_t>
(start, start + step, out, out.
size
());
58
break
;
59
case
float16
:
60
arange<float16_t>
(start, start + step, out, out.
size
());
61
break
;
62
case
float32
:
63
arange<float>
(start, start + step, out, out.
size
());
64
break
;
65
case
bfloat16
:
66
arange<bfloat16_t>
(start, start + step, out, out.
size
());
67
break
;
68
case
complex64
:
69
arange<complex64_t>
(start, start + step, out, out.
size
());
70
break
;
71
}
72
}
73
74
}
// namespace mlx::core
allocator.h
array.h
next
BufferHolder * next
Definition
allocator.h:38
mlx::core::array
Definition
array.h:23
mlx::core::array::nbytes
size_t nbytes() const
The number of bytes in the array.
Definition
array.h:92
mlx::core::array::size
size_t size() const
The number of elements in the array.
Definition
array.h:87
mlx::core::array::dtype
Dtype dtype() const
Get the arrays data type.
Definition
array.h:130
mlx::core::array::set_data
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
mlx::core::allocator::malloc_or_wait
Buffer malloc_or_wait(size_t size)
mlx::core
Definition
allocator.h:7
mlx::core::bool_
constexpr Dtype bool_
Definition
dtype.h:67
mlx::core::uint64
constexpr Dtype uint64
Definition
dtype.h:72
mlx::core::uint16
constexpr Dtype uint16
Definition
dtype.h:70
mlx::core::arange
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition
arange.h:24
mlx::core::bfloat16
constexpr Dtype bfloat16
Definition
dtype.h:81
mlx::core::int32
constexpr Dtype int32
Definition
dtype.h:76
mlx::core::float32
constexpr Dtype float32
Definition
dtype.h:80
mlx::core::int16
constexpr Dtype int16
Definition
dtype.h:75
mlx::core::int8
constexpr Dtype int8
Definition
dtype.h:74
mlx::core::int64
constexpr Dtype int64
Definition
dtype.h:77
mlx::core::uint8
constexpr Dtype uint8
Definition
dtype.h:69
mlx::core::float16
constexpr Dtype float16
Definition
dtype.h:79
mlx::core::uint32
constexpr Dtype uint32
Definition
dtype.h:71
mlx::core::complex64
constexpr Dtype complex64
Definition
dtype.h:82
Generated by
1.12.0