From: @fangzehua Reviewed-by: @stsuteng Signed-off-by:tags/v1.1.0
| @@ -18,11 +18,17 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); } | |||||
| void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| node_ = kernel_node; | |||||
| x_data_type_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| type_size_ = GetTypeByte(TypeIdToType(x_data_type_)); | |||||
| } | |||||
| bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | const std::vector<kernel::AddressPtr> & /*workspace*/, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | |||||
| if (inputs.empty() || outputs.empty()) { | if (inputs.empty() || outputs.empty()) { | ||||
| MS_LOG(EXCEPTION) << "input or output empty!"; | MS_LOG(EXCEPTION) << "input or output empty!"; | ||||
| } | } | ||||
| @@ -34,7 +40,10 @@ bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| return true; | return true; | ||||
| } | } | ||||
| size_t mem_bits = outputs[0]->size; | |||||
| size_t mem_bits = type_size_; | |||||
| for (size_t i = 0; i < x_shape.size(); ++i) { | |||||
| mem_bits *= x_shape[i]; | |||||
| } | |||||
| auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); | auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | ||||
| @@ -31,6 +31,11 @@ class ReshapeCPUKernel : public CPUKernel { | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| private: | |||||
| CNodePtr node_ = nullptr; | |||||
| TypeId x_data_type_{kNumberTypeInt32}; | |||||
| size_t type_size_ = 4; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| @@ -566,54 +566,69 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| const std::string &op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 2); | |||||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| auto reshape = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||||
| auto input_shp = input->shape()->shape(); | |||||
| auto reshape_val = reshape->BuildValue(); | |||||
| if (reshape_val->isa<AnyValue>()) { | |||||
| MS_LOG(EXCEPTION) << "Input_shape can't be anything: " << args_spec_list[1]->ToString(); | |||||
| } | |||||
| auto reshape_val_data = reshape_val->cast<ValueTuplePtr>()->value(); | |||||
| ShapeVector reshape_vec; | |||||
| (void)std::transform(std::begin(reshape_val_data), std::end(reshape_val_data), std::back_inserter(reshape_vec), | |||||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||||
| ShapeVector result_shp; | |||||
| auto input_prod = input_shp[0]; | |||||
| int64_t dim_prod = 1; | |||||
| size_t neg_idx = 0; | |||||
| for (size_t i = 1; i < input_shp.size(); i++) { | |||||
| input_prod *= input_shp[i]; | |||||
| } | |||||
| auto num_neg_one = count(std::begin(reshape_vec), std::end(reshape_vec), -1); | |||||
| if (num_neg_one > 1) { | |||||
| MS_LOG(EXCEPTION) << "The shape can only has one -1 at most, but " << num_neg_one; | |||||
| } | |||||
| for (size_t i = 0; i < reshape_vec.size(); i++) { | |||||
| if (reshape_vec[i] == -1) { | |||||
| neg_idx = i; | |||||
| result_shp.push_back(-1); | |||||
| } else { | |||||
| dim_prod *= reshape_vec[i]; | |||||
| result_shp.push_back(reshape_vec[i]); | |||||
| const std::string op_name = primitive->name(); | |||||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(x); | |||||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||||
| ShapeVector shape; | |||||
| ShapeVector x_shape = x->shape()->shape(); | |||||
| ShapeVector x_max_shape = x->shape()->max_shape(); | |||||
| ShapeVector x_min_shape = x->shape()->min_shape(); | |||||
| if (x_max_shape.empty()) { | |||||
| x_max_shape = x_shape; | |||||
| } | |||||
| if (x_min_shape.empty()) { | |||||
| x_min_shape = x_shape; | |||||
| } | |||||
| ValuePtr sh = primitive->GetAttr("shape"); | |||||
| auto reshape_value_tuple = sh->cast<ValueTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(reshape_value_tuple); | |||||
| auto reshape_tuple = reshape_value_tuple->value(); | |||||
| (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape), | |||||
| [](const ValuePtr &e) -> int { return GetValue<int>(e); }); | |||||
| auto max_shape = shape; | |||||
| auto min_shape = shape; | |||||
| int x_num = 1; | |||||
| int x_min_num = 1; | |||||
| int x_max_num = 1; | |||||
| for (int value : x_shape) { | |||||
| x_num = IntMulWithOverflowCheck(value, x_num); | |||||
| } | |||||
| for (int value : x_min_shape) { | |||||
| x_min_num = IntMulWithOverflowCheck(value, x_min_num); | |||||
| } | |||||
| for (int value : x_max_shape) { | |||||
| x_max_num = IntMulWithOverflowCheck(value, x_max_num); | |||||
| } | |||||
| auto it_first = find(shape.begin(), shape.end(), -1); | |||||
| if (it_first != shape.end()) { | |||||
| auto it_second = find(it_first + 1, shape.end(), -1); | |||||
| if (it_second != shape.end()) { | |||||
| MS_LOG(EXCEPTION) << "At most one component of input shape can be -1"; | |||||
| } | } | ||||
| int index = std::distance(it_first, shape.begin()); | |||||
| int infer_value = x_num; | |||||
| int infer_min_value = x_min_num; | |||||
| int infer_max_value = x_max_num; | |||||
| for (size_t i = 0; i < shape.size(); ++i) { | |||||
| int value = shape[i]; | |||||
| if (value != -1 && value != 0) { | |||||
| infer_value = infer_value / value; | |||||
| infer_min_value = infer_min_value / value; | |||||
| infer_max_value = infer_max_value / value; | |||||
| } | |||||
| } | |||||
| shape[index] = infer_value; | |||||
| min_shape[index] = infer_min_value; | |||||
| max_shape[index] = infer_max_value; | |||||
| } | } | ||||
| if (dim_prod < 0 || input_prod % dim_prod != 0) { | |||||
| MS_LOG(EXCEPTION) << "The input_x shape product is " << input_prod << ", input_shape shape product is " << dim_prod | |||||
| << ", and this value should be > 0 and should divide product of input_x."; | |||||
| } | |||||
| if (num_neg_one == 1) { | |||||
| int64_t val = static_cast<int64_t>(input_prod) / dim_prod; | |||||
| dim_prod *= val; | |||||
| result_shp[neg_idx] = val; | |||||
| } | |||||
| if (dim_prod != input_prod) { | |||||
| MS_LOG(EXCEPTION) | |||||
| << "The product of input_x shape should be equal to product of input_shape shape, but input_x shape is " | |||||
| << input_prod << ", product of input_shape shape is " << dim_prod; | |||||
| } | |||||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp)); | |||||
| AbstractTensorPtr ret = | |||||
| std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||||
| return ret; | |||||
| } | } | ||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -445,6 +445,14 @@ class Reshape(PrimitiveWithInfer): | |||||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | ||||
| validator.check_value_type("shape", shape_v, [tuple], self.name) | validator.check_value_type("shape", shape_v, [tuple], self.name) | ||||
| shape_v = list(shape_v) | shape_v = list(shape_v) | ||||
| if 'max_shape' in x: | |||||
| x_max_shape = x['max_shape'] | |||||
| else: | |||||
| x_max_shape = x['shape'] | |||||
| if 'min_shape' in x: | |||||
| x_min_shape = x['min_shape'] | |||||
| else: | |||||
| x_min_shape = x['shape'] | |||||
| neg_index = -1 | neg_index = -1 | ||||
| dim_prod = 1 | dim_prod = 1 | ||||
| for i, shp_i in enumerate(shape_v): | for i, shp_i in enumerate(shape_v): | ||||
| @@ -456,14 +464,19 @@ class Reshape(PrimitiveWithInfer): | |||||
| else: | else: | ||||
| dim_prod *= shp_i | dim_prod *= shp_i | ||||
| arr_prod = np.prod(x_shp) | arr_prod = np.prod(x_shp) | ||||
| max_arr_prod = np.prod(x_max_shape) | |||||
| min_arr_prod = np.prod(x_min_shape) | |||||
| if dim_prod <= 0 or arr_prod % dim_prod != 0: | if dim_prod <= 0 or arr_prod % dim_prod != 0: | ||||
| raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.' | raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.' | ||||
| f'The product of input_x\'s shape should > 0, ' | f'The product of input_x\'s shape should > 0, ' | ||||
| f'and can be divided by product of input_shape, ' | f'and can be divided by product of input_shape, ' | ||||
| f'but product of input_x\'s shape is {arr_prod}, product of input_shape is {dim_prod}.') | f'but product of input_x\'s shape is {arr_prod}, product of input_shape is {dim_prod}.') | ||||
| max_shape = list(shape_v) | |||||
| min_shape = list(shape_v) | |||||
| if neg_index != -1: | if neg_index != -1: | ||||
| shape_v[neg_index] = int(arr_prod / dim_prod) | shape_v[neg_index] = int(arr_prod / dim_prod) | ||||
| max_shape[neg_index] = int(max_arr_prod / dim_prod) | |||||
| min_shape[neg_index] = int(min_arr_prod / dim_prod) | |||||
| dim_prod *= shape_v[neg_index] | dim_prod *= shape_v[neg_index] | ||||
| if dim_prod != arr_prod: | if dim_prod != arr_prod: | ||||
| raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.' | raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.' | ||||
| @@ -476,7 +489,9 @@ class Reshape(PrimitiveWithInfer): | |||||
| out = {'shape': tuple(shape_v), | out = {'shape': tuple(shape_v), | ||||
| 'dtype': x['dtype'], | 'dtype': x['dtype'], | ||||
| 'value': value} | |||||
| 'value': value, | |||||
| 'max_shape': tuple(max_shape), | |||||
| 'min_shape': tuple(min_shape)} | |||||
| return out | return out | ||||