|
|
|
@@ -15,23 +15,17 @@ |
|
|
|
""" test_framstruct """ |
|
|
|
import pytest |
|
|
|
import numpy as np |
|
|
|
import mindspore as ms |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore import context |
|
|
|
from mindspore.ops import composite as C |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
from mindspore.common.parameter import Parameter, ParameterTuple |
|
|
|
from mindspore.common.initializer import initializer |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore.nn.wrap.cell_wrapper import WithGradCell, WithLossCell |
|
|
|
from ..ut_filter import non_graph_engine |
|
|
|
from ....mindspore_test_framework.utils.check_gradient import ( |
|
|
|
ms_function, check_jacobian, Tensor, NNGradChecker, |
|
|
|
OperationGradChecker, check_gradient, ScalarGradChecker) |
|
|
|
from ....mindspore_test_framework.utils.bprop_util import bprop |
|
|
|
import mindspore.context as context |
|
|
|
from mindspore.ops._grad.grad_base import bprop_getters |
|
|
|
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer |
|
|
|
|
|
|
|
@@ -299,22 +293,22 @@ def test_dont_unroll_while(): |
|
|
|
assert res == 3 |
|
|
|
|
|
|
|
class ConvNet(nn.Cell):
    """Minimal conv network used as a test fixture.

    Applies a single 16-out-channel 3x3 Conv2D (dilation=2, no padding)
    whose weight is a constant all-ones tensor, so outputs are fully
    deterministic for gradient/structure tests.

    Fix: the class previously contained two byte-identical copies of
    `__init__` and `construct`; the second silently shadowed the first.
    The duplicate pair has been removed — behavior is unchanged.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        out_channel = 16
        kernel_size = 3
        # pad_mode="pad" with pad=0 means no implicit padding; dilation=2
        # widens the receptive field without adding parameters.
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=0,
                             stride=1,
                             dilation=2,
                             group=1)
        # Constant weight of shape [out_channel, in_channel, kH, kW].
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        """Convolve input x with the fixed all-ones weight."""
        return self.conv(x, self.w)
|
|
|
|
|
|
|
# Module-level fixtures shared by the conv-related tests below.
conv = ConvNet()

# Scalar float32 tensor used as a constant input in later tests.
c1 = Tensor([2], mstype.float32)
|
|
|
@@ -674,7 +668,7 @@ def grad_refactor_6(a, b): |
|
|
|
|
|
|
|
|
|
|
|
def test_grad_refactor_6():
    """Check that grad_all of grad_refactor_6 at (3, 2) equals (3, 1).

    Fix: the previous body first evaluated the same comparison as a bare
    expression statement (a no-op whose result was discarded) before the
    real assert; that dead line has been removed.
    """
    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
|
|
|
|
|
|
|
|
|
|
|
def grad_refactor_while(x): |
|
|
|
|