
fix refkey bug for auto parallel

tag: v0.3.0-alpha
lichenever, chang zherui · 6 years ago
commit 89945c15f5

2 changed files with 50 additions and 7 deletions:
  1. mindspore/ccsrc/parallel/step_parallel.cc (+19, -2)
  2. tests/ut/python/parallel/test_arithmetic.py (+31, -5)

mindspore/ccsrc/parallel/step_parallel.cc (+19, -2)

@@ -49,6 +49,9 @@ namespace mindspore {
 namespace parallel {
 const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
+// g_RefMap: if input i of CNode B is a RefKey[Parameter C], the map holds
+// an entry with key C and value (B, i).
+static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
 
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   if (new_node_input.empty()) {
@@ -1085,11 +1088,19 @@ std::vector<Shapes> ExtractShape(const CNodePtr& node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
   std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
 
-  for (auto& input : node_inputs) {
+  size_t inputs_size = all_inputs.size();
+  for (size_t i = 1; i < inputs_size; ++i) {
     Shapes input_shapes;
+    AnfNodePtr input = all_inputs[i];
     if (IsValueNode<RefKey>(input)) {
       auto func_graph = node->func_graph();
       MS_EXCEPTION_IF_NULL(func_graph);
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
+      g_RefMap[parameters[0]] = node_pair;
       input_shapes = GetRefKeyNodeShape(input, func_graph);
     } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
       input_shapes = GetNodeShape(input);
@@ -1205,14 +1216,20 @@ void CoverSliceShape(const FuncGraphPtr& root) {
   auto parameters = root->parameters();
   for (auto& parameter : parameters) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
+    auto iter = g_RefMap.find(parameter);
+    if (iter != g_RefMap.end()) {
+      SetParallelShape(parameter, g_RefMap[parameter]);
+      continue;
+    }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
       SetParallelShape(parameter, res);
-      MS_LOG(DEBUG) << "parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+      MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }
+  g_RefMap.clear();
 }
 
 bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) {
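The change is two-phase bookkeeping. While ExtractShape walks a CNode's inputs, each RefKey input is resolved to its backing Parameter and recorded in g_RefMap together with the consuming (node, input index) pair; CoverSliceShape then prefers that recorded pair over a fresh FindSubGraph search and clears the map when done, so no state leaks into the next compilation. A minimal sketch of the scheme in plain Python (names and the callback signatures are hypothetical, not MindSpore API; the real g_RefMap is keyed by AnfNodePtr inside step_parallel.cc):

ref_map = {}  # parameter -> (consumer node, input index), like g_RefMap

def record_ref_key(parameter, consumer, input_index):
    # Phase 1 (ExtractShape): remember that `consumer` reads `parameter`
    # through a RefKey at position `input_index`.
    ref_map[parameter] = (consumer, input_index)

def cover_slice_shapes(parameters, find_sub_graph, set_parallel_shape):
    # Phase 2 (CoverSliceShape): reuse the recorded pair when present,
    # fall back to a graph search otherwise, then reset the map.
    for param in parameters:
        pair = ref_map.get(param)
        if pair is None:
            pair = find_sub_graph(param)  # may be None: nothing to set
        if pair is not None:
            set_parallel_shape(param, pair)
    ref_map.clear()  # mirrors g_RefMap.clear(): no state carried across graphs

For example, after record_ref_key("assignsub_weight", "AssignSub", 1), covering ["assignsub_weight"] reaches set_parallel_shape with the recorded pair even when find_sub_graph cannot locate the parameter, which is the failure the commit fixes.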


tests/ut/python/parallel/test_arithmetic.py (+31, -5)

@@ -13,14 +13,13 @@
 # limitations under the License.
 
 import numpy as np
-from mindspore import context
+import mindspore as ms
+from mindspore import Parameter, Tensor, context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
-from tests.ut.python.ops.test_math_ops import VirtualLoss
-import mindspore as ms
-from mindspore.common.api import _executor
 from mindspore.ops import composite as C
+from mindspore.common.api import _executor
+from tests.ut.python.ops.test_math_ops import VirtualLoss
 
 
 class NetWithLoss(nn.Cell):
@@ -470,3 +469,30 @@ def test_matmul_floordiv_broadcast2():
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
     _executor.compile(net, x, y, b)
+
+
+def test_assign_sub():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.assign_sub = P.AssignSub()
+            self.mul = P.Mul()
+            self.mul_weight = Parameter(Tensor(np.full([128, 32],
+                                                       0.5, dtype=np.float32)),
+                                        name="mul_weight")
+            self.assignsub_weight = Parameter(Tensor(np.full([128, 32],
+                                                             1.1, dtype=np.float32)),
+                                              name="assignsub_weight")
+
+        def construct(self, x, y, z):
+            out = self.mul(x, self.mul_weight)
+            out = self.assign_sub(self.assignsub_weight, out)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=15)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net = GradWrap(NetWithLoss(Net()))
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    z = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y, z)
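The new test exercises the RefKey path: AssignSub updates assignsub_weight in place, so the compiled graph passes that Parameter to the op by reference, which is the RefKey input kind ExtractShape now records in g_RefMap. A toy version of the guarded lookup (hypothetical names and dict encoding; the real FindParameterByRefKeyNode walks the ANF graph):

def resolve_ref_key(ref_key_name, parameters):
    # Return the unique parameter a RefKey names; raising otherwise mirrors
    # the "Find parameter by ref key node failed" check in ExtractShape.
    matches = [p for p in parameters if p["name"] == ref_key_name]
    if len(matches) != 1:
        raise RuntimeError("Find parameter by ref key node failed")
    return matches[0]

params = [{"name": "assignsub_weight", "shape": [128, 32]}]
param = resolve_ref_key("assignsub_weight", params)
assert param["shape"] == [128, 32]  # this shape feeds the op's input_shapes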
