|
|
|
@@ -18,6 +18,7 @@ import pytest |
|
|
|
|
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore import Tensor |
|
|
|
from mindspore.ops import composite as C |
|
|
|
from mindspore.common.api import _executor |
|
|
|
|
|
|
|
|
|
|
|
@@ -93,3 +94,25 @@ def test_compile_unspported(): |
|
|
|
net = unsupported_method_net() |
|
|
|
with pytest.raises(RuntimeError): |
|
|
|
_executor.compile(net, input_me) |
|
|
|
|
|
|
|
|
|
|
|
def test_parser_map_0002(): |
|
|
|
class NetMap0002(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.relu = nn.ReLU() |
|
|
|
self.hypermap = C.Map() |
|
|
|
|
|
|
|
def mul(self, x=2, y=4): |
|
|
|
return x * y |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
if map(self.mul) == 8: |
|
|
|
x = self.relu(x) |
|
|
|
return x |
|
|
|
input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32) |
|
|
|
input_me_x = Tensor(input_np_x) |
|
|
|
|
|
|
|
net = NetMap0002() |
|
|
|
with pytest.raises(TypeError): |
|
|
|
net(input_me_x) |