add slice shape for param info

tags/v0.3.0-alpha
yangzhenzhang, 5 years ago
commit 05fde3d23d

3 changed files with 12 additions and 11 deletions
  1. mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc  (+2, -1)
  2. mindspore/parallel/_tensor.py  (+8, -8)
  3. tests/ut/python/parallel/test_get_parameter_layout.py  (+2, -2)

mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc  (+2, -1)

@@ -42,7 +42,8 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
   } else {
     auto device_arrangement = tensor_layout->device_arrangement().array();
     auto tensor_map = tensor_layout->tensor_map().array();
-    std::pair<std::vector<int32_t>, std::vector<int32_t>> layout(device_arrangement, tensor_map);
+    auto slice_shape = tensor_layout->slice_shape().array();
+    std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
     dict[py::str(name)] = layout;
     MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
   }
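
With this change, each entry of the dict returned by GetParameterLayout() carries three int32 vectors instead of a pair, so it crosses pybind11 into Python as a plain list of lists. A minimal sketch of the new per-parameter value; the concrete numbers are illustrative, taken from the test below:

    # Illustrative only: shape of one entry of net.parameter_layout_dict
    # after this commit, as seen from Python.
    layout = [
        [2, 4],    # device_arrangement: the logical device matrix
        [1, -1],   # tensor_map: device dim splitting each tensor dim (-1 = not split)
        [16, 32],  # slice_shape: local slice shape per device (new in this commit)
    ]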


mindspore/parallel/_tensor.py  (+8, -8)

@@ -203,19 +203,19 @@ def _load_tensor_by_layout(tensor, layout):
 
     Args:
         tensor (Tensor): The input tensor.
-        layout (tuple): The tensor layout in auto parallel.
+        layout (list): The tensor layout in auto parallel.
 
     Returns:
-        Tensor, the sliced tensor..
+        Tensor, the sliced tensor.
 
     Raises:
-        TypeError: If layout is not tuple.
-        ValueError: If the length of layout is not 2.
+        TypeError: If layout is not list.
+        ValueError: If the length of layout is not 3.
     """
-    if not isinstance(layout, tuple):
-        raise TypeError("layout should be tuple! layout is {}".format(layout))
-    if len(layout) != 2:
-        raise ValueError("The length of layout must be 2! layout is {}".format(layout))
+    if not isinstance(layout, list):
+        raise TypeError("The layout should be list! layout is {}".format(layout))
+    if len(layout) != 3:
+        raise ValueError("The length of layout must be 3! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
     if tensor.size() == 1:
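
For context, a usage sketch of the updated contract. _load_tensor_by_layout is a private helper and in a real run the slice returned depends on the rank of the calling device, so the tensor contents and shapes here are assumptions for illustration:

    import numpy as np
    from mindspore import Tensor
    from mindspore.parallel._tensor import _load_tensor_by_layout

    # Hypothetical full tensor; the layout slices dim 0 across the 2-device axis.
    full = Tensor(np.arange(32 * 32).reshape(32, 32).astype(np.float32))
    layout = [[2, 4], [1, -1], [16, 32]]  # dev_mat, tensor_map, slice_shape
    sliced = _load_tensor_by_layout(full, layout)  # a [16, 32] slice on this rank

    # Passing the old 2-element tuple now fails fast:
    # _load_tensor_by_layout(full, ([2, 4], [1, -1]))  -> TypeError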


tests/ut/python/parallel/test_get_parameter_layout.py  (+2, -2)

@@ -48,8 +48,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, auto_parallel_mode=True)
-    x_layout = ([2, 4], [1, -1])  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = ([2, 4], [0, -1])  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert (net.parameter_layout_dict == expect_dict)
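
The expected [16, 32] slice shapes follow from the device matrix and tensor map, assuming MindSpore's convention that a tensor_map value k names the device-matrix axis counted from the right (k = 0 is the last axis, -1 means unsplit), and assuming full shapes of [32, 32] for x and [64, 32] for w1, which are inferred from the expected values rather than shown in this diff:

    def expected_slice_shape(full_shape, dev_mat, tensor_map):
        """Sketch of the slice-shape arithmetic under the assumed convention."""
        slice_shape = []
        for dim, k in zip(full_shape, tensor_map):
            # -1 means this tensor dimension is not split across devices
            divisor = 1 if k == -1 else dev_mat[len(dev_mat) - 1 - k]
            slice_shape.append(dim // divisor)
        return slice_shape

    assert expected_slice_shape([32, 32], [2, 4], [1, -1]) == [16, 32]  # x
    assert expected_slice_shape([64, 32], [2, 4], [0, -1]) == [16, 32]  # w1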

