mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Initial implementation
This commit is contained in:
@@ -3688,17 +3688,7 @@ TEST_CASE("test conv1d") {
|
||||
}
|
||||
|
||||
TEST_CASE("test conv2d") {
|
||||
auto in = array(
|
||||
{0.57429284,
|
||||
-0.21628855,
|
||||
-0.18673691,
|
||||
-0.3793517,
|
||||
|
||||
0.3059678,
|
||||
-0.8137168,
|
||||
0.6168841,
|
||||
-0.26912728},
|
||||
{1, 2, 2, 2});
|
||||
array in = zeros({1, 2, 2, 3}, float32);
|
||||
|
||||
std::pair<int, int> kernel{2, 2};
|
||||
std::pair<int, int> stride{1, 1};
|
||||
@@ -3707,15 +3697,7 @@ TEST_CASE("test conv2d") {
|
||||
{
|
||||
int groups = 1;
|
||||
|
||||
auto wt = array(
|
||||
{0.3190391, -0.24937038, 1.4621079, -2.0601406, -0.3224172,
|
||||
-0.38405436, 1.1337694, -1.0998913, -0.1724282, -0.8778584,
|
||||
0.04221375, 0.58281523, -1.1006192, 1.1447237, 0.9015907,
|
||||
0.50249434, 0.90085596, -0.68372786, -0.12289023, -0.93576944,
|
||||
-0.26788807, 0.53035545, -0.69166076, -0.39675352, -0.6871727,
|
||||
-0.84520566, -0.6712461, -0.0126646, -1.1173104, 0.2344157,
|
||||
1.6598022, 0.74204415},
|
||||
{4, 2, 2, 2});
|
||||
array wt = ones({1, 2, 2, 3}, float32);
|
||||
|
||||
auto expected =
|
||||
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
|
||||
|
||||
Reference in New Issue
Block a user