Merge pull request !2631 from amongo/FixExpandDimsOpstags/v0.6.0-beta
| @@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { | |||||
| {prim::kPrimApplyRMSProp, {6, 7, 8}}, | {prim::kPrimApplyRMSProp, {6, 7, 8}}, | ||||
| {prim::kPrimCumSum, {2}}, | {prim::kPrimCumSum, {2}}, | ||||
| {prim::kPrimTile, {2}}, | {prim::kPrimTile, {2}}, | ||||
| {prim::kPrimExpandDims, {2}}, | |||||
| {prim::kPrimHistogramSummary, {1}}}); | {prim::kPrimHistogramSummary, {1}}}); | ||||
| for (auto &item : white_list) { | for (auto &item : white_list) { | ||||
| auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { | auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { | ||||
| @@ -30,6 +30,8 @@ class ControlDepend(Primitive): | |||||
| tells the engine that the destination operations should depend on the source operation which means the source | tells the engine that the destination operations should depend on the source operation which means the source | ||||
| operations should be executed before the destination. | operations should be executed before the destination. | ||||
| Note: | |||||
| This operation does not work in `PYNATIVE_MODE`. | |||||
| Args: | Args: | ||||
| depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0. | depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0. | ||||
| @@ -19,6 +19,8 @@ import pytest | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common import dtype as ms | |||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| @@ -116,3 +118,28 @@ def test_parser_map_0002(): | |||||
| net = NetMap0002() | net = NetMap0002() | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| net(input_me_x) | net(input_me_x) | ||||
| def test_fix_expanddims_loss_scale(): | |||||
| class ControlOneIfOneScaleOneScale(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.op = P.ExpandDims() | |||||
| def construct(self, x, y, data): | |||||
| if x > y: | |||||
| out = 1 | |||||
| else: | |||||
| out = 2 | |||||
| if x > y: | |||||
| out = self.op(data, out) | |||||
| else: | |||||
| out = self.op(data, out) | |||||
| return out | |||||
| net = ControlOneIfOneScaleOneScale() | |||||
| x = Tensor(1, ms.float32) | |||||
| y = Tensor(0, ms.float32) | |||||
| input_shape = (1024, 512, 7, 7) | |||||
| input_data = np.random.randn(*input_shape).astype(np.float32) | |||||
| net = ControlOneIfOneScaleOneScale() | |||||
| net(x, y, Tensor(input_data)) | |||||