Browse Source

!2631 Add ExpandDims to whitelist

Merge pull request !2631 from amongo/FixExpandDimsOps
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
bc30576ac9
3 changed files with 30 additions and 0 deletions
  1. +1
    -0
      mindspore/ccsrc/optimizer/irpass/branch_culling.cc
  2. +2
    -0
      mindspore/ops/operations/control_ops.py
  3. +27
    -0
      tests/ut/python/pipeline/parse/test_fix_bug.py

+ 1
- 0
mindspore/ccsrc/optimizer/irpass/branch_culling.cc View File

@@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
{prim::kPrimApplyRMSProp, {6, 7, 8}},
{prim::kPrimCumSum, {2}},
{prim::kPrimTile, {2}},
{prim::kPrimExpandDims, {2}},
{prim::kPrimHistogramSummary, {1}}});
for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {


+ 2
- 0
mindspore/ops/operations/control_ops.py View File

@@ -30,6 +30,8 @@ class ControlDepend(Primitive):
tells the engine that the destination operations should depend on the source operation which means the source
operations should be executed before the destination.

Note:
This operation does not work in `PYNATIVE_MODE`.
Args:
depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0.



+ 27
- 0
tests/ut/python/pipeline/parse/test_fix_bug.py View File

@@ -19,6 +19,8 @@ import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as ms
from mindspore.common.api import _executor


@@ -116,3 +118,28 @@ def test_parser_map_0002():
net = NetMap0002()
with pytest.raises(TypeError):
net(input_me_x)


def test_fix_expanddims_loss_scale():
class ControlOneIfOneScaleOneScale(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.ExpandDims()

def construct(self, x, y, data):
if x > y:
out = 1
else:
out = 2
if x > y:
out = self.op(data, out)
else:
out = self.op(data, out)
return out
net = ControlOneIfOneScaleOneScale()
x = Tensor(1, ms.float32)
y = Tensor(0, ms.float32)
input_shape = (1024, 512, 7, 7)
input_data = np.random.randn(*input_shape).astype(np.float32)
net = ControlOneIfOneScaleOneScale()
net(x, y, Tensor(input_data))

Loading…
Cancel
Save