| @@ -35,14 +35,7 @@ Status ResizeBilinearInfo::GetAttrs() { | |||
| return FAILED; | |||
| } | |||
| if (size_[0] != size_[1]) { | |||
| MS_LOG(ERROR) << name_ << ": The second two elements of size must be the same, but got (" << size_[0] << ", " | |||
| << size_[1] << ")"; | |||
| return FAILED; | |||
| } | |||
| align_corners_ = GetBoolAttr(ALIGN_CORNERS); | |||
| MS_LOG(INFO) << name_ << ": The input size is " << size_ << ", align_corners is " << align_corners_; | |||
| return SUCCESS; | |||
| @@ -85,7 +78,15 @@ Status ResizeBilinearInfo::InferDevMatrixShape() { | |||
| return FAILED; | |||
| } | |||
| if (stra[0].size() != 4) { | |||
| MS_LOG(ERROR) << name_ << ": The size of strategy must be 4, but got " << stra[0].size(); | |||
| return FAILED; | |||
| } | |||
| dev_matrix_shape_ = stra[0]; | |||
| slice_size_ = size_; | |||
| slice_size_[0] = slice_size_[0] / dev_matrix_shape_[2]; | |||
| slice_size_[1] = slice_size_[1] / dev_matrix_shape_[3]; | |||
| return SUCCESS; | |||
| } | |||
| @@ -135,5 +136,40 @@ Status ResizeBilinearInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() { | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode_->input(0)); | |||
| prim->set_attr(SIZE, MakeValue(slice_size_)); | |||
| } | |||
| Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| MS_EXCEPTION_IF_NULL(strategy); | |||
| // check input strategy | |||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Check input strategy failed"; | |||
| return FAILED; | |||
| } | |||
| // check output strategy | |||
| if (CheckStrategyValue(strategy, outputs_shape_) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Check output strategy failed"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| Shape multiples_split(inputs_shape_[0].size(), 1); | |||
| Shapes splittable_inputs = {multiples_split}; | |||
| std::vector<StrategyPtr> sp_vector; | |||
| if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << name_ << ": generate strategies failed"; | |||
| } | |||
| return sp_vector; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -47,11 +47,11 @@ class ResizeBilinearInfo : public OperatorInfo { | |||
| Status InferForwardCommunication() override { return SUCCESS; } | |||
| Status InferDevMatrixShape() override; | |||
| Status InferTensorMap() override; | |||
| Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy); | |||
| void ReplaceNodeInputOrAttrs() override; | |||
| private: | |||
| std::vector<int64_t> size_; // four integers, NCHW | |||
| bool align_corners_; | |||
| std::vector<int64_t> size_; | |||
| std::vector<int64_t> slice_size_; | |||
| bool align_corners_ = false; | |||
| }; | |||
| class ResizeNearestNeighborInfo : public ResizeBilinearInfo { | |||
| @@ -60,6 +60,10 @@ class ResizeNearestNeighborInfo : public ResizeBilinearInfo { | |||
| const PrimitiveAttrs &attrs) | |||
| : ResizeBilinearInfo(name, inputs_shape, outputs_shape, attrs) {} | |||
| ~ResizeNearestNeighborInfo() override = default; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| protected: | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| }; | |||
| } // namespace parallel | |||
| @@ -3830,7 +3830,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name) | |||
| return tuple(x_shape)[:-2] + tuple(self.size) | |||
| return tuple(x_shape)[:-2] + tuple(super().get_attr_dict()['size']) | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||
| @@ -43,17 +43,20 @@ class Net2(Cell): | |||
| ''' | |||
| create the test Net | |||
| ''' | |||
| def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, | |||
| def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride, | |||
| strategy1=None, strategy2=None): | |||
| super(Net2, self).__init__() | |||
| self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, | |||
| pad_mode=pad_mode, stride=stride).shard(strategy1) | |||
| self.conv2d_weight = Parameter(conv2d_weight, "w1") | |||
| self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2) | |||
| self.mul = P.Mul() | |||
| self.mul_weight = Parameter(mul_weight, "w2") | |||
| def construct(self, x): | |||
| out = self.conv2d(x, self.conv2d_weight) | |||
| out = self.resize_neighbor(out) | |||
| out = self.mul(out, self.mul_weight) | |||
| return out | |||
| class Net3(Cell): | |||
| @@ -76,6 +79,7 @@ class Net3(Cell): | |||
| _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) | |||
| _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32) | |||
| _w2 = Tensor(np.ones([32, 8, 16, 16]), dtype=ms.float32) | |||
| def compile_net(net, inputs=_x): | |||
| @@ -130,21 +134,21 @@ def test_neighbor_data_parallel(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) | |||
| strategy2 = ((8, 1, 1, 1),) | |||
| net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, | |||
| net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, | |||
| strategy1=strategy1, strategy2=strategy2) | |||
| compile_net(net) | |||
| def test_neighbor_model_parallel1(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) | |||
| strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1)) | |||
| strategy2 = ((4, 2, 1, 1),) | |||
| net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, | |||
| strategy2 = ((2, 2, 2, 2),) | |||
| net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1, | |||
| strategy1=strategy1, strategy2=strategy2) | |||
| compile_net(net) | |||
| def test_neighbor_auto_parallel(): | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) | |||
| net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1) | |||
| net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1) | |||
| compile_net(net) | |||