
!5722 fix semi auto parallel parameter of reshape has another user

Merge pull request !5722 from yao_yf/semi_auto_parallel_reshape_parameter_has_another_user
Tags: v1.0.0
mindspore-ci-bot committed 5 years ago
commit 7786adc3aa
3 changed files with 52 additions and 0 deletions:
  1. mindspore/ccsrc/frontend/parallel/step_parallel.cc (+28, -0)
  2. mindspore/ccsrc/frontend/parallel/step_parallel.h (+2, -0)
  3. tests/ut/python/parallel/test_auto_parallel_reshape.py (+22, -0)
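
In short: CreateParameterLayout previously always built a data-parallel layout for a parameter. This commit adds FindParameterNextLayout, which walks the parameter's users and, when a parallel-care operator consumes the parameter directly, reuses that operator's required input layout instead. The failing pattern looks roughly like the sketch below, a trimmed copy of the unit test added in this commit (it assumes the test file's usual imports: mindspore as ms, numpy as np, nn, P, Parameter, Tensor):

    # Sketch only: a parameter with two users, a Reshape and a Mul.
    # Before this fix, the parameter's layout was created as data-parallel
    # even though Mul's strategy demands a different layout for it.
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.mul = P.Mul().set_strategy(((1, 2, 4), (2, 4)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

        def construct(self, x):
            weight = self.reshape(self.mul_weight, (1, 128, 96))  # user 1: Reshape
            return self.mul(weight, self.mul_weight)              # user 2: Mul, same parameter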

mindspore/ccsrc/frontend/parallel/step_parallel.cc (+28, -0)

@@ -1645,8 +1645,36 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
  return nullptr;
}

std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
  FuncGraphManagerPtr manager = node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[node];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) {
      continue;
    }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      auto layout = GetInputLayoutFromCNode(node_pair);
      return std::make_shared<TensorLayout>(layout);
    }
  }
  return nullptr;
}

std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
  // Create DataParallel tensor layout for parameter(support WideDeep).
  auto next_layout = FindParameterNextLayout(node);
  if (next_layout != nullptr) {
    return next_layout;
  }
  CheckGlobalDeviceManager();
  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
  TensorLayout input_tensor_layout;
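
Note on the control flow above: FindParameterNextLayout returns the input layout required by the parameter's first parallel-care user, skipping Reshape users (Reshape resolves its layout elsewhere) and Depend users where the parameter is not Depend's first real input. Only when no such user exists does CreateParameterLayout fall back to building the data-parallel layout from the device count, as in the truncated context above.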


mindspore/ccsrc/frontend/parallel/step_parallel.h (+2, -0)

@@ -156,6 +156,8 @@ using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeI

RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode);

std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);

ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
} // namespace parallel
} // namespace mindspore


tests/ut/python/parallel/test_auto_parallel_reshape.py (+22, -0)

@@ -292,3 +292,25 @@ def test_reshape_auto_6():
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x, y)


def test_reshape_auto_7():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.mul = P.Mul().set_strategy(((1, 2, 4), (2, 4)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

        def construct(self, x):
            weight = self.reshape(self.mul_weight, (1, 128, 96))
            out = self.mul(weight, self.mul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 28]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x)
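
The new test exercises exactly the fixed scenario: mul_weight feeds a Reshape and is simultaneously a direct input of Mul with strategy ((1, 2, 4), (2, 4)) under semi_auto_parallel; the test passes if compilation succeeds.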
