Browse Source

!14888 add testcases related to if, fix codedex check

From: @huangbingjian
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
pull/14888/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
1d2e37be88
13 changed files with 751 additions and 17 deletions
  1. +22
    -17
      mindspore/ccsrc/frontend/optimizer/opt.cc
  2. +3
    -0
      mindspore/ccsrc/frontend/optimizer/opt.h
  3. +60
    -0
      tests/st/control/inner/test_000_single_if.py
  4. +64
    -0
      tests/st/control/inner/test_010_if_in_if.py
  5. +65
    -0
      tests/st/control/inner/test_100_if_after_if.py
  6. +67
    -0
      tests/st/control/inner/test_110_if_after_if_in_if.py
  7. +66
    -0
      tests/st/control/inner/test_112_if_after_if_in_for.py
  8. +67
    -0
      tests/st/control/inner/test_130_if_after_for_in_if.py
  9. +67
    -0
      tests/st/control/inner/test_131_if_after_for_in_while.py
  10. +67
    -0
      tests/st/control/inner/test_132_if_after_for_in_for.py
  11. +66
    -0
      tests/st/control/inner/test_300_for_after_if.py
  12. +69
    -0
      tests/st/control/inner/test_310_for_after_if_in_if.py
  13. +68
    -0
      tests/st/control/inner/test_330_for_after_for_in_if.py

+ 22
- 17
mindspore/ccsrc/frontend/optimizer/opt.cc View File

@@ -103,8 +103,8 @@ static bool isTraversable(const AnfNodePtr &node) {
return false;
}

static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
const SubstitutionPtr &substitution) {
static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
const SubstitutionPtr &substitution) {
auto manager = optimizer->manager();
bool is_match = substitution->predicate_(node);
if (is_match) {
@@ -126,8 +126,8 @@ static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNod
return nullptr;
}

static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo,
bool change, size_t seen) {
if (IsValueNode<FuncGraph>(node)) {
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
}
@@ -238,6 +238,23 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
return changes;
}

void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
const OptimizerPtr &optimizer, size_t space) const {
std::stringstream ss;
ss << std::endl
<< "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
<< std::endl;
for (size_t i = 0; i < list_.size(); i++) {
auto name = list_[i]->name_;
ss << std::left << std::setw(space + 4) << name << "\t";
for (auto change : status.at(name + std::to_string(i))) {
ss << change << " ";
}
ss << std::endl;
}
MS_LOG(DEBUG) << ss.str();
}

bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
// Add for substitution status counting
size_t space = 0;
@@ -282,19 +299,7 @@ bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, con

// Display the status of each substitution
if (optimizer->is_on_debug_) {
std::stringstream ss;
ss << std::endl
<< "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
<< std::endl;
for (size_t i = 0; i < list_.size(); i++) {
auto name = list_[i]->name_;
ss << std::left << std::setw(space + 4) << name << "\t";
for (auto change : status[name + std::to_string(i)]) {
ss << change << " ";
}
ss << std::endl;
}
MS_LOG(DEBUG) << ss.str();
DisplayStatusOfSubstitution(status, optimizer, space);
}
return changes;
}


+ 3
- 0
mindspore/ccsrc/frontend/optimizer/opt.h View File

@@ -20,6 +20,7 @@
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>

#include "ir/anf.h"
#include "ir/func_graph.h"
@@ -74,6 +75,8 @@ class SubstitutionList {
bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
const OptimizerPtr &optimizer, size_t space) const;

std::vector<SubstitutionPtr> list_;
// a flag to mark this list of Substitution can only be executed only once


+ 60
- 0
tests/st/control/inner/test_000_single_if.py View File

@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_signle_if():
class SignleIfNet(nn.Cell):
def construct(self, x, y):
x += 1
if x < y:
y += x
else:
y -= x
y += 5
return y

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)
y = Tensor(5, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
if_net = SignleIfNet()
net = GradNet(if_net)
graph_forward_res = if_net(x, y)
graph_backward_res = net(x, y)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_net = SignleIfNet()
net = GradNet(if_net)
pynative_forward_res = if_net(x, y)
pynative_backward_res = net(x, y)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 64
- 0
tests/st/control/inner/test_010_if_in_if.py View File

@@ -0,0 +1,64 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_if_in_if():
class IfInIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

def construct(self, x):
if self.param_a > self.param_b:
x += 10
if x > self.param_a:
self.param_b += 1
x += self.param_a
return x

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
if_in_if_net = IfInIfNet()
net = GradNet(if_in_if_net)
graph_forward_res = if_in_if_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_in_if_net = IfInIfNet()
net = GradNet(if_in_if_net)
pynative_forward_res = if_in_if_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 65
- 0
tests/st/control/inner/test_100_if_after_if.py View File

@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_if_after_if():
class IfAfterIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

def construct(self, x):
out = x + self.param_b
if self.param_a > self.param_b:
x += 5
self.param_b += 4
if x < self.param_b:
out += self.param_b
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
if_after_if_net = IfAfterIfNet()
net = GradNet(if_after_if_net)
graph_forward_res = if_after_if_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_if_net = IfAfterIfNet()
net = GradNet(if_after_if_net)
pynative_forward_res = if_after_if_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 67
- 0
tests/st/control/inner/test_110_if_after_if_in_if.py View File

@@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_if_after_if_in_if():
class IfAfterIfInIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

def construct(self, x):
out = x + self.param_b
if self.param_a > self.param_b:
x += 5
if x > self.param_a:
self.param_b += 1
self.param_b += 3
if x < self.param_b:
out += self.param_b
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
if_after_if_in_if_net = IfAfterIfInIfNet()
net = GradNet(if_after_if_in_if_net)
graph_forward_res = if_after_if_in_if_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_if_in_if_net = IfAfterIfInIfNet()
net = GradNet(if_after_if_in_if_net)
pynative_forward_res = if_after_if_in_if_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 66
- 0
tests/st/control/inner/test_112_if_after_if_in_for.py View File

@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_if_after_if_in_for():
class IfAfterIfInForNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

def construct(self, x):
out = x + self.param_b
for _ in range(4):
if out <= 20:
out += self.param_a
self.param_b += 3
if x < self.param_b:
out -= self.param_b
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
if_after_if_in_for_net = IfAfterIfInForNet()
net = GradNet(if_after_if_in_for_net)
graph_forward_res = if_after_if_in_for_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_if_in_for_net = IfAfterIfInForNet()
net = GradNet(if_after_if_in_for_net)
pynative_forward_res = if_after_if_in_for_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 67
- 0
tests/st/control/inner/test_130_if_after_for_in_if.py View File

@@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_if_after_for_in_if():
class IfAfterForInIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

def construct(self, x):
out = x + self.param_a
if self.param_a > self.param_b:
for _ in range(4):
self.param_a += 1
self.param_b -= 3
self.param_b += 15
if x < self.param_b:
out -= self.param_b
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
if_after_for_in_if_net = IfAfterForInIfNet()
net = GradNet(if_after_for_in_if_net)
graph_forward_res = if_after_for_in_if_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_for_in_if_net = IfAfterForInIfNet()
net = GradNet(if_after_for_in_if_net)
pynative_forward_res = if_after_for_in_if_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 67
- 0
tests/st/control/inner/test_131_if_after_for_in_while.py View File

@@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_if_after_for_in_while():
class IfAfterForInWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(2, mstype.int32), name='b')

def construct(self, x):
out = x + self.param_a
while self.param_a > self.param_b:
self.param_b += 1
for _ in range(4):
self.param_a += 3
self.param_a -= 40
if x > self.param_a:
out += self.param_a * 10
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
if_after_for_in_while_net = IfAfterForInWhileNet()
net = GradNet(if_after_for_in_while_net)
graph_forward_res = if_after_for_in_while_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_for_in_while_net = IfAfterForInWhileNet()
net = GradNet(if_after_for_in_while_net)
pynative_forward_res = if_after_for_in_while_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 67
- 0
tests/st/control/inner/test_132_if_after_for_in_for.py View File

@@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_if_after_for_in_for():
class IfAfterForInForNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(2, mstype.int32), name='b')

def construct(self, x):
out = x + self.param_a
for _ in range(0, 10):
x *= 2
for _ in range(0, 5):
self.param_a += 1
x += self.param_b
if self.param_a > self.param_b:
out += x
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
if_after_for_in_for_net = IfAfterForInForNet()
net = GradNet(if_after_for_in_for_net)
graph_forward_res = if_after_for_in_for_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_for_in_for_net = IfAfterForInForNet()
net = GradNet(if_after_for_in_for_net)
pynative_forward_res = if_after_for_in_for_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 66
- 0
tests/st/control/inner/test_300_for_after_if.py View File

@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_for_after_if():
class ForAfterIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

def construct(self, x):
out = self.param_a
if self.param_a > self.param_b:
x += 3
self.param_b += 1
for _ in range(0, 5):
x += self.param_b
out *= x
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(2, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
for_after_if_net = ForAfterIfNet()
net = GradNet(for_after_if_net)
graph_forward_res = for_after_if_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_if_net = ForAfterIfNet()
net = GradNet(for_after_if_net)
pynative_forward_res = for_after_if_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 69
- 0
tests/st/control/inner/test_310_for_after_if_in_if.py View File

@@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_for_after_if_in_if():
class ForAfterIfInIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

def construct(self, x):
out = self.param_a
if self.param_a > self.param_b:
x += 3
if x > self.param_a:
self.param_b += 4
x += self.param_a
self.param_b += 2
for _ in range(0, 5):
x += self.param_b
out *= x
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(5, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
for_after_if_in_if_net = ForAfterIfInIfNet()
net = GradNet(for_after_if_in_if_net)
graph_forward_res = for_after_if_in_if_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_if_in_if_net = ForAfterIfInIfNet()
net = GradNet(for_after_if_in_if_net)
pynative_forward_res = for_after_if_in_if_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

+ 68
- 0
tests/st/control/inner/test_330_for_after_for_in_if.py View File

@@ -0,0 +1,68 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

def test_for_after_for_in_if():
class ForAfterForInIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

def construct(self, x):
out = self.param_a
if self.param_a > self.param_b:
for _ in range(0, 4):
self.param_a += 1
self.param_b -= 3
self.param_b += 10
for _ in range(0, 5):
x += self.param_b
out *= x
return out

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net

def construct(self, *inputs):
return grad_all(self.net)(*inputs)

x = Tensor(5, mstype.int32)

# graph mode
context.set_context(mode=context.GRAPH_MODE)
for_after_for_in_if_net = ForAfterForInIfNet()
net = GradNet(for_after_for_in_if_net)
graph_forward_res = for_after_for_in_if_net(x)
graph_backward_res = net(x)

# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_for_in_if_net = ForAfterForInIfNet()
net = GradNet(for_after_for_in_if_net)
pynative_forward_res = for_after_for_in_if_net(x)
pynative_backward_res = net(x)

assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

Loading…
Cancel
Save