Add view op (#1179)

* add view primitive

* nit

* fix view
This commit is contained in:
Awni Hannun
2024-06-04 08:05:27 -07:00
committed by GitHub
parent 81def6ac76
commit ea9090bbc4
14 changed files with 202 additions and 11 deletions

View File

@@ -3159,8 +3159,6 @@ void init_ops(nb::module_& m) {
R"pbdoc(
1D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): input array of shape (``N``, ``H``, ``C_in``)
weight (array): weight array of shape (``C_out``, ``H``, ``C_in``)
@@ -3219,8 +3217,6 @@ void init_ops(nb::module_& m) {
R"pbdoc(
2D convolution over an input with several channels
Note: Only the default ``groups=1`` is currently supported.
Args:
input (array): input array of shape ``(N, H, W, C_in)``
weight (array): weight array of shape ``(C_out, H, W, C_in)``
@@ -3390,11 +3386,6 @@ void init_ops(nb::module_& m) {
R"pbdoc(
General convolution over an input with several channels
.. note::
* Only 1d and 2d convolutions are supported at the moment
* the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, ..., C_in)``
weight (array): Weight array of shape ``(C_out, ..., C_in)``
@@ -4356,4 +4347,32 @@ void init_ops(nb::module_& m) {
Returns:
array: The bitwise right shift ``a >> b``.
)pbdoc");
m.def(
"view",
[](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) {
return view(to_array(a), dtype, s);
},
nb::arg(),
"dtype"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def view(a: Union[scalar, array], dtype: Dtype, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
View the array as a different type.
The output shape changes along the last axis if the input array's
type and the input ``dtype`` do not have the same size.
Note: the view op does not imply that the input and output arrays share
their underlying data. The view only gaurantees that the binary
representation of each element (or group of elements) is the same.
Args:
a (array): Input array or scalar.
dtype (Dtype): The data type to change to.
Returns:
array: The array with the new type.
)pbdoc");
}