| @@ -42,7 +42,8 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { | |||||
| } else { | } else { | ||||
| auto device_arrangement = tensor_layout->device_arrangement().array(); | auto device_arrangement = tensor_layout->device_arrangement().array(); | ||||
| auto tensor_map = tensor_layout->tensor_map().array(); | auto tensor_map = tensor_layout->tensor_map().array(); | ||||
| std::pair<std::vector<int32_t>, std::vector<int32_t>> layout(device_arrangement, tensor_map); | |||||
| auto slice_shape = tensor_layout->slice_shape().array(); | |||||
| std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape}; | |||||
| dict[py::str(name)] = layout; | dict[py::str(name)] = layout; | ||||
| MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); | MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); | ||||
| } | } | ||||
| @@ -203,19 +203,19 @@ def _load_tensor_by_layout(tensor, layout): | |||||
| Args: | Args: | ||||
| tensor (Tensor): The input tensor. | tensor (Tensor): The input tensor. | ||||
| layout (tuple): The tensor layout in auto parallel. | |||||
| layout (list): The tensor layout in auto parallel. | |||||
| Returns: | Returns: | ||||
| Tensor, the sliced tensor.. | |||||
| Tensor, the sliced tensor. | |||||
| Raises: | Raises: | ||||
| TypeError: If layout is not tuple. | |||||
| ValueError: If the length of layout is not 2. | |||||
| TypeError: If layout is not list. | |||||
| ValueError: If the length of layout is not 3. | |||||
| """ | """ | ||||
| if not isinstance(layout, tuple): | |||||
| raise TypeError("layout should be tuple! layout is {}".format(layout)) | |||||
| if len(layout) != 2: | |||||
| raise ValueError("The length of layout must be 2! layout is {}".format(layout)) | |||||
| if not isinstance(layout, list): | |||||
| raise TypeError("The layout should be list! layout is {}".format(layout)) | |||||
| if len(layout) != 3: | |||||
| raise ValueError("The length of layout must be 3! layout is {}".format(layout)) | |||||
| dev_mat = layout[0] | dev_mat = layout[0] | ||||
| tensor_map = layout[1] | tensor_map = layout[1] | ||||
| if tensor.size() == 1: | if tensor.size() == 1: | ||||
| @@ -48,8 +48,8 @@ def test_get_parameter_layout(): | |||||
| net.set_auto_parallel() | net.set_auto_parallel() | ||||
| exe = me._executor | exe = me._executor | ||||
| exe.compile(net, x, auto_parallel_mode=True) | exe.compile(net, x, auto_parallel_mode=True) | ||||
| x_layout = ([2, 4], [1, -1]) # device_arrangement = [2, 4], tensor_map = [1, -1] | |||||
| weight_layout = ([2, 4], [0, -1]) # device_arrangement = [2, 4], tensor_map = [0, -1] | |||||
| x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1] | |||||
| weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1] | |||||
| expect_dict = {'x': x_layout, 'w1': weight_layout} | expect_dict = {'x': x_layout, 'w1': weight_layout} | ||||
| # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut | # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut | ||||
| assert (net.parameter_layout_dict == expect_dict) | assert (net.parameter_layout_dict == expect_dict) | ||||