| @@ -39,7 +39,7 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStr | |||||
| Status ReshapeInfo::InferDevMatrixShape() { | Status ReshapeInfo::InferDevMatrixShape() { | ||||
| Strategys stra = strategy_->GetInputDim(); | Strategys stra = strategy_->GetInputDim(); | ||||
| input_strategy_ = stra.at(0); | input_strategy_ = stra.at(0); | ||||
| dev_matrix_shape_.push_back(input_strategy_[0]); | |||||
| dev_matrix_shape_ = stra.at(0); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -162,17 +162,13 @@ Status ReshapeInfo::InferTensorMap() { | |||||
| } | } | ||||
| Shape tensor_map_index_input; | Shape tensor_map_index_input; | ||||
| tensor_map_index_input.push_back(0); | |||||
| for (size_t j = 1; j < inputs_shape_[0].size(); ++j) { | |||||
| tensor_map_index_input.push_back(MAP_NONE); | |||||
| for (size_t j = 0; j < inputs_shape_[0].size(); ++j) { | |||||
| tensor_map_index_input.push_back((int64_t)(inputs_shape_[0].size() - j - 1)); | |||||
| } | } | ||||
| inputs_tensor_map_.push_back(tensor_map_index_input); | inputs_tensor_map_.push_back(tensor_map_index_input); | ||||
| Shape tensor_map_index_output; | Shape tensor_map_index_output; | ||||
| tensor_map_index_output.push_back(0); | |||||
| for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { | |||||
| for (size_t j = 0; j < outputs_shape_[0].size(); ++j) { | |||||
| tensor_map_index_output.push_back(MAP_NONE); | tensor_map_index_output.push_back(MAP_NONE); | ||||
| } | } | ||||
| outputs_tensor_map_.push_back(tensor_map_index_output); | outputs_tensor_map_.push_back(tensor_map_index_output); | ||||
| @@ -186,8 +182,7 @@ Status ReshapeInfo::InferTensorMap() { | |||||
| Strategys ReshapeInfo::GetOutputsStrategy() { | Strategys ReshapeInfo::GetOutputsStrategy() { | ||||
| Strategys outputs_strategy; | Strategys outputs_strategy; | ||||
| Dimensions strategy; | Dimensions strategy; | ||||
| strategy.push_back(input_strategy_[0]); | |||||
| for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { | |||||
| for (size_t j = 0; j < outputs_shape_[0].size(); ++j) { | |||||
| strategy.push_back(1); | strategy.push_back(1); | ||||
| } | } | ||||
| outputs_strategy.push_back(strategy); | outputs_strategy.push_back(strategy); | ||||
| @@ -74,7 +74,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape1) { | |||||
| reshape->Init(strategy); | reshape->Init(strategy); | ||||
| Shape dev_matrix_shape = reshape->dev_matrix_shape(); | Shape dev_matrix_shape = reshape->dev_matrix_shape(); | ||||
| Shape expect = {4, 8}; | |||||
| Shape expect = {4, 1, 1, 1, 8}; | |||||
| ASSERT_EQ(dev_matrix_shape, expect); | ASSERT_EQ(dev_matrix_shape, expect); | ||||
| } | } | ||||
| @@ -85,7 +85,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape2) { | |||||
| reshape->Init(strategy); | reshape->Init(strategy); | ||||
| Shape dev_matrix_shape = reshape->dev_matrix_shape(); | Shape dev_matrix_shape = reshape->dev_matrix_shape(); | ||||
| Shape expect = {32}; | |||||
| Shape expect = {32, 1, 1, 1}; | |||||
| ASSERT_EQ(dev_matrix_shape, expect); | ASSERT_EQ(dev_matrix_shape, expect); | ||||
| } | } | ||||
| @@ -98,7 +98,7 @@ TEST_F(TestReshapeInfo, InferSliceShape1) { | |||||
| std::vector<TensorInfo> outputs = reshape->outputs_tensor_info(); | std::vector<TensorInfo> outputs = reshape->outputs_tensor_info(); | ||||
| Shape input_slice_shape_expect = {8, 512, 7, 7}; | Shape input_slice_shape_expect = {8, 512, 7, 7}; | ||||
| Shape output_slice_shape_expect = {8, 25088}; | |||||
| Shape output_slice_shape_expect = {32, 25088}; | |||||
| TensorInfo input_tensor_info = inputs.at(0); | TensorInfo input_tensor_info = inputs.at(0); | ||||
| TensorInfo output_tensor_info = outputs.at(0); | TensorInfo output_tensor_info = outputs.at(0); | ||||
| @@ -119,7 +119,7 @@ TEST_F(TestReshapeInfo, InferSliceShape2) { | |||||
| std::vector<TensorInfo> outputs = reshape->outputs_tensor_info(); | std::vector<TensorInfo> outputs = reshape->outputs_tensor_info(); | ||||
| Shape input_slice_shape_expect = {1, 512, 7, 7}; | Shape input_slice_shape_expect = {1, 512, 7, 7}; | ||||
| Shape output_slice_shape_expect = {1, 25088}; | |||||
| Shape output_slice_shape_expect = {32, 25088}; | |||||
| TensorInfo input_tensor_info = inputs.at(0); | TensorInfo input_tensor_info = inputs.at(0); | ||||
| TensorInfo output_tensor_info = outputs.at(0); | TensorInfo output_tensor_info = outputs.at(0); | ||||
| @@ -139,8 +139,8 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) { | |||||
| std::vector<TensorInfo> inputs = reshape->inputs_tensor_info(); | std::vector<TensorInfo> inputs = reshape->inputs_tensor_info(); | ||||
| std::vector<TensorInfo> outputs = reshape->outputs_tensor_info(); | std::vector<TensorInfo> outputs = reshape->outputs_tensor_info(); | ||||
| TensorMap input_expect = {1, -1, -1, -1}; | |||||
| TensorMap output_expect = {1, -1}; | |||||
| TensorMap input_expect = {4, 3, 2, 1}; | |||||
| TensorMap output_expect = {-1, -1}; | |||||
| TensorInfo input_tensor_info = inputs.at(0); | TensorInfo input_tensor_info = inputs.at(0); | ||||
| TensorInfo output_tensor_info = outputs.at(0); | TensorInfo output_tensor_info = outputs.at(0); | ||||
| @@ -160,8 +160,8 @@ TEST_F(TestReshapeInfo, GetTensorLayout2) { | |||||
| std::vector<TensorInfo> inputs = reshape->inputs_tensor_info(); | std::vector<TensorInfo> inputs = reshape->inputs_tensor_info(); | ||||
| std::vector<TensorInfo> outputs = reshape->outputs_tensor_info(); | std::vector<TensorInfo> outputs = reshape->outputs_tensor_info(); | ||||
| TensorMap input_expect = {0, -1, -1, -1}; | |||||
| TensorMap output_expect = {0, -1}; | |||||
| TensorMap input_expect = {3, 2, 1, 0}; | |||||
| TensorMap output_expect = {-1, -1}; | |||||
| TensorInfo input_tensor_info = inputs.at(0); | TensorInfo input_tensor_info = inputs.at(0); | ||||
| TensorInfo output_tensor_info = outputs.at(0); | TensorInfo output_tensor_info = outputs.at(0); | ||||
| @@ -74,12 +74,12 @@ def test_reshape_unexpand_1(): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.mul = P.Mul().shard(((1, 8), (1, 1, 8))) | |||||
| self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight") | |||||
| self.mul = P.Mul().shard(((1, 1, 8), (1, 8))) | |||||
| self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight") | |||||
| def construct(self, x): | |||||
| weight = self.reshape(self.mul_weight, (1, 128, 96)) | |||||
| out = self.mul(x, weight) | |||||
| def construct(self, data): | |||||
| x = self.reshape(self.mul_weight, (1, 128, 96)) | |||||
| out = self.mul(x, self.mul_weight) | |||||
| return out | return out | ||||
| size = 8 | size = 8 | ||||
| @@ -236,3 +236,25 @@ def test_reshape_unexpand_7(): | |||||
| net = GradWrap(NetWithLoss(Net())) | net = GradWrap(NetWithLoss(Net())) | ||||
| net.set_auto_parallel() | net.set_auto_parallel() | ||||
| _executor.compile(net, x) | _executor.compile(net, x) | ||||
| def test_reshape_unexpand_8(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.reshape = P.Reshape() | |||||
| self.mul = P.Mul().shard(((1, 4, 2), (4, 2))) | |||||
| self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight") | |||||
| def construct(self, data): | |||||
| x = self.reshape(self.mul_weight, (1, 128, 96)) | |||||
| out = self.mul(x, self.mul_weight) | |||||
| return out | |||||
| size = 8 | |||||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||||
| x = Tensor(np.ones([128, 96]), dtype=ms.float32) | |||||
| net = GradWrap(NetWithLoss(Net())) | |||||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||||
| net.set_auto_parallel() | |||||
| _executor.compile(net, x) | |||||