Browse Source

回退 'Pull Request !17 : [AutoParallel]Fix bug in the case of two cast'

tags/v0.3.0-alpha
leonwanghui chang zherui 6 years ago
parent
commit
4c2aa41f1d
3 changed files with 4 additions and 38 deletions
  1. +0
    -2
      mindspore/ccsrc/parallel/step_auto_parallel.cc
  2. +4
    -7
      mindspore/ccsrc/parallel/step_parallel.cc
  3. +0
    -29
      tests/ut/python/parallel/test_element_wise_function.py

+ 0
- 2
mindspore/ccsrc/parallel/step_auto_parallel.cc View File

@@ -350,8 +350,6 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
}

OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(cnode);
auto attrs = prim->attrs();
std::vector<Shapes> shape_list = ExtractShape(cnode);
if (shape_list.empty()) {


+ 4
- 7
mindspore/ccsrc/parallel/step_parallel.cc View File

@@ -374,6 +374,7 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
if (prim == nullptr) {
return false;
}
auto attrs = prim->attrs();
if (IsInBlackList(prim)) {
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
return false;
@@ -653,13 +654,6 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) {
MS_EXCEPTION_IF_NULL(pre_node);

LossNodeInfo node_info;
// return -> cast
auto pre_cnode = pre_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
pre_node = pre_cnode->input(1);
}

// return -> cast
auto pre_cnode = pre_node->cast<CNodePtr>();
@@ -1978,6 +1972,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) {
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(current_prim);
<<<<<<< HEAD
<<<<<<< HEAD

=======
>>>>>>> fix_cast_bug
@@ -1988,6 +1983,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) {
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
}

=======
>>>>>>> 回退 'Pull Request !17 : [AutoParallel]Fix bug in the case of two cast'
// notice: the GetNext op has not input
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
MS_LOG(INFO) << "The loss is: " << current_prim->name();


+ 0
- 29
tests/ut/python/parallel/test_element_wise_function.py View File

@@ -268,32 +268,3 @@ def test_cast_before_mirror3():
y = Tensor(np.ones([32, 64]), dtype=ms.float16)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)


def test_mul_two_cast():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, strategy3):
super().__init__()
self.mul = P.Mul().set_strategy(strategy1)
self.mul2 = P.Mul().set_strategy(strategy2)
self.cast = P.Cast().set_strategy(strategy3)
self.cast2 = P.Cast().set_strategy(strategy3)

def construct(self, x, y, b):
out = self.mul(x, y)
out = self.mul2(out, b)
out = self.cast(out, ms.int32)
out = self.cast2(out, ms.bool_)
return out

context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((2, 2), (2, 2))
strategy2 = ((8, 1), (8, 1))
strategy3 = ((8, 1), )
net = GradWrap(Net(strategy1, strategy2, strategy3))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
b = Tensor(np.ones([128, 32]), dtype=ms.float32)
_executor.compile(net, x, y, b)

Loading…
Cancel
Save