|
|
|
@@ -19,7 +19,7 @@ from mindspore.nn import Cell |
|
|
|
from mindspore.ops import operations as P |
|
|
|
import mindspore.ops.composite as C |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) |
|
|
|
|
|
|
|
def test_parser_three_default_mixed_args_subnet(): |
|
|
|
|
|
|
|
@@ -227,3 +227,43 @@ def test_net_vargs_expand(): |
|
|
|
|
|
|
|
net.set_train() |
|
|
|
net(x, y, sens) |
|
|
|
|
|
|
|
|
|
|
|
def test_mixed_precision_const_parameter():
    """Compile and run a net whose varargs feed graph-mode if/elif/else branching.

    NetMain passes two spatial sizes (taken from input shapes via P.Shape) into
    NetLoss as *args; NetLoss branches on args[0] against both a shape lookup
    and the constants 14/28.  With the inputs built below, shape(x)[2] == 28 and
    shape(z)[2] == 28, so args[0] == 28 selects the `up_sample2(y) - x` branch.
    The whole net is flagged fp32 via add_flags_recursive before execution.
    NOTE(review): presumably this guards mixed-precision handling of constant
    shape parameters through varargs in GRAPH_MODE — confirm against the
    original test intent.
    """



    class NetLoss(Cell):
        # Inner loss net: selects an upsampling path based on its varargs.



        def __init__(self):



            super(NetLoss, self).__init__()



            self.shape = P.Shape()



            # Three fixed-size bilinear resize ops; only up_sample1/up_sample2
            # are reachable from construct — up_sample3 appears intentionally
            # unused (dead-branch coverage? confirm).
            self.up_sample1 = P.ResizeBilinear((14, 14))



            self.up_sample2 = P.ResizeBilinear((28, 28))



            self.up_sample3 = P.ResizeBilinear((36, 36))



        def construct(self, x, y, z, *args):
            # args[0] is expected to be a spatial size (see NetMain: size_x);
            # args[1] (size_y) is received but never read.



            ret = 0



            # Outer test: does the passed-in size match z's 3rd dimension?
            if args[0] == self.shape(z)[2]:



                # Inner constant tests exercise elif chains over vararg values.
                if args[0] == 14:



                    ret = self.up_sample1(y) + x



                elif args[0] == 28:



                    ret = self.up_sample2(y) - x



                else:



                    ret = x / y



            else:



                ret = x * y



            ret = ret * z



            return ret



    class NetMain(Cell):
        # Outer net: extracts spatial sizes from inputs and forwards them
        # to the loss net as extra positional (vararg) arguments.



        def __init__(self, loss_fn):



            super(NetMain, self).__init__()



            self.loss_fn = loss_fn



            self.shape = P.Shape()



        def construct(self, x, y, z):



            # shape(...)[2] is the height dimension for NCHW inputs — TODO
            # confirm layout; with the tensors below size_x=28, size_y=14.
            size_x = self.shape(x)[2]



            size_y = self.shape(y)[2]



            ret = self.loss_fn(x, y, z, size_x, size_y)



            return ret



    loss_fn = NetLoss()



    net = NetMain(loss_fn)



    # Force the whole net (recursively, all sub-cells) to fp32 — the
    # "mixed precision const parameter" condition under test.
    net.add_flags_recursive(fp32=True)



    x = Tensor(np.ones((1, 3, 28, 28), np.float32))



    y = Tensor(np.ones((1, 3, 14, 14), np.float32))



    z = Tensor(np.ones((1, 3, 28, 28), np.float32))



    # Executes in GRAPH_MODE (set at module level); result is not asserted —
    # the test passes if compilation and execution succeed.
    out = net(x, y, z)