Browse Source

!2631 Add ExpandDims to whitelist

Merge pull request !2631 from amongo/FixExpandDimsOps
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
bc30576ac9
3 changed files with 30 additions and 0 deletions
  1. +1
    -0
      mindspore/ccsrc/optimizer/irpass/branch_culling.cc
  2. +2
    -0
      mindspore/ops/operations/control_ops.py
  3. +27
    -0
      tests/ut/python/pipeline/parse/test_fix_bug.py

+ 1
- 0
mindspore/ccsrc/optimizer/irpass/branch_culling.cc View File

@@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
{prim::kPrimApplyRMSProp, {6, 7, 8}}, {prim::kPrimApplyRMSProp, {6, 7, 8}},
{prim::kPrimCumSum, {2}}, {prim::kPrimCumSum, {2}},
{prim::kPrimTile, {2}}, {prim::kPrimTile, {2}},
{prim::kPrimExpandDims, {2}},
{prim::kPrimHistogramSummary, {1}}}); {prim::kPrimHistogramSummary, {1}}});
for (auto &item : white_list) { for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {


+ 2
- 0
mindspore/ops/operations/control_ops.py View File

@@ -30,6 +30,8 @@ class ControlDepend(Primitive):
tells the engine that the destination operations should depend on the source operation which means the source tells the engine that the destination operations should depend on the source operation which means the source
operations should be executed before the destination. operations should be executed before the destination.


Note:
This operation does not work in `PYNATIVE_MODE`.
Args: Args:
depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0. depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0.




+ 27
- 0
tests/ut/python/pipeline/parse/test_fix_bug.py View File

@@ -19,6 +19,8 @@ import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as ms
from mindspore.common.api import _executor from mindspore.common.api import _executor




@@ -116,3 +118,28 @@ def test_parser_map_0002():
net = NetMap0002() net = NetMap0002()
with pytest.raises(TypeError): with pytest.raises(TypeError):
net(input_me_x) net(input_me_x)


def test_fix_expanddims_loss_scale():
class ControlOneIfOneScaleOneScale(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.ExpandDims()

def construct(self, x, y, data):
if x > y:
out = 1
else:
out = 2
if x > y:
out = self.op(data, out)
else:
out = self.op(data, out)
return out
net = ControlOneIfOneScaleOneScale()
x = Tensor(1, ms.float32)
y = Tensor(0, ms.float32)
input_shape = (1024, 512, 7, 7)
input_data = np.random.randn(*input_shape).astype(np.float32)
net = ControlOneIfOneScaleOneScale()
net(x, y, Tensor(input_data))

Loading…
Cancel
Save