
support reshape parameter

tags/v0.3.0-alpha
lichenever 5 years ago
parent commit 2ab211ae04
3 changed files with 104 additions and 3 deletions
  1. +24 -1   mindspore/ccsrc/parallel/step_parallel.cc
  2. +5  -2   mindspore/context.py
  3. +75 -0   tests/ut/python/parallel/test_reshape_parameter.py

+24 -1  mindspore/ccsrc/parallel/step_parallel.cc

@@ -1523,9 +1523,32 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
   return nullptr;
 }
 
+std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
+  // Create DataParallel tensor layout for parameter(support WideDeep).
+  CheckGlobalDeviceManager();
+  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
+  TensorLayout input_tensor_layout;
+  // create input_shape
+  Shapes inputs_shape = GetNodeShape(node);
+  Shape input_shape_array = inputs_shape[0];
+  if (input_shape_array.empty()) {
+    MS_LOG(EXCEPTION) << "Don't support reshape a scalar parameter.";
+  }
+  // create tensor_map
+  size_t shape_size = input_shape_array.size();
+  TensorMap input_tensor_map_array(SizeToInt(shape_size) - 1, -1);
+  input_tensor_map_array.insert(input_tensor_map_array.begin(), 0);
+  // create dev_matrix
+  Shape dev_matrix_array = {dev_num};
+  if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed.";
+  }
+  return std::make_shared<TensorLayout>(input_tensor_layout);
+}
+
 std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
   if (node->isa<Parameter>()) {
-    MS_LOG(EXCEPTION) << "Failure: parameter before reshape is not supported temporary";
+    return CreateParameterLayout(node);
   }
   if (!node->isa<CNode>()) {
     return nullptr;
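For intuition, here is a minimal Python sketch of the layout CreateParameterLayout builds, using plain lists as stand-ins for the C++ Shape/TensorMap types (create_parameter_layout is a hypothetical helper, not a MindSpore API): one device-matrix axis of size dev_num, the parameter's first dimension mapped to that axis (a data-parallel split), and every remaining dimension replicated (-1).

def create_parameter_layout(input_shape, dev_num):
    # Mirrors the construction above: scalar parameters are rejected,
    # dim 0 is mapped to the single device-matrix axis, other dims are replicated.
    if not input_shape:
        raise ValueError("Don't support reshape a scalar parameter.")
    tensor_map = [0] + [-1] * (len(input_shape) - 1)
    dev_matrix = [dev_num]
    return dev_matrix, tensor_map, input_shape

# e.g. a (10000, 36) parameter on 8 devices:
# dev_matrix [8], tensor_map [0, -1], shape [10000, 36]
print(create_parameter_layout([10000, 36], 8))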


+5 -2  mindspore/context.py

@@ -415,8 +415,11 @@ def set_auto_parallel_context(**kwargs):
     Args:
         device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
         global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
-        mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
-        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
+        mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
+                     "stand_alone" does not support mirror_mean. Default: False.
+        cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
+                     "stand_alone", "data_parallel" and "hybrid_parallel" do not support
+                     cast_before_mirror. Default: True.
         parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                      "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".



+75 -0  tests/ut/python/parallel/test_reshape_parameter.py

@@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from tests.ut.python.ops.test_math_ops import VirtualLoss
import numpy as np


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y):
        return C.grad_all(self.network)(x, y)


class Net(nn.Cell):
    def __init__(self, strategy):
        super().__init__()
        self.reshape = P.Reshape()
        self.mul = P.Mul().set_strategy(strategy)
        self.relu = P.ReLU()

    def construct(self, x, y):
        out = self.reshape(x, (10000, 36, 1))
        out = self.mul(out, y)
        out = self.relu(out)
        return out


def test_reshape_parameter_data_parallel():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy = ((8, 1, 1), (8, 1, 1))
    net = GradWrap(NetWithLoss(Net(strategy)))
    x = Tensor(np.ones([10000, 36]), dtype=ms.float32)
    y = Tensor(np.ones([10000, 36, 1]), dtype=ms.float32)
    _executor.compile(net, x, y)


def test_reshape_parameter_model_parallel():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy = ((4, 2, 1), (4, 2, 1))
    net = GradWrap(NetWithLoss(Net(strategy)))
    x = Tensor(np.ones([10000, 36]), dtype=ms.float32)
    y = Tensor(np.ones([10000, 36, 1]), dtype=ms.float32)
    _executor.compile(net, x, y)
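Both tests above feed the tensor that gets reshaped in as a network input. A hypothetical variant that reshapes an explicit Parameter, which is the case CreateParameterLayout is written to handle, could look like the sketch below; this is illustrative only and is not part of the commit.

from mindspore import Parameter

class ParamReshapeNet(nn.Cell):
    # Hypothetical net: the weight is held as a Parameter and reshaped before Mul,
    # so the layout of the node before Reshape comes from CreateParameterLayout.
    def __init__(self, strategy):
        super().__init__()
        self.weight = Parameter(Tensor(np.ones([10000, 36]), dtype=ms.float32), name="weight")
        self.reshape = P.Reshape()
        self.mul = P.Mul().set_strategy(strategy)

    def construct(self, y):
        out = self.reshape(self.weight, (10000, 36, 1))
        return self.mul(out, y)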
