From 212f65fc02a9a91fb27e7915bf257b695026c4f4 Mon Sep 17 00:00:00 2001
From: fangzehua
Date: Fri, 13 Nov 2020 16:16:57 +0800
Subject: [PATCH] add reshape dynamic

---
 .../kernel_compiler/cpu/reshape_cpu_kernel.cc | 13 ++-
 .../kernel_compiler/cpu/reshape_cpu_kernel.h  |  5 +
 mindspore/core/abstract/prim_arrays.cc        | 99 +++++++++++--------
 mindspore/ops/operations/array_ops.py         | 19 +++-
 4 files changed, 90 insertions(+), 46 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc
index e0183307f5..022b252764 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc
@@ -18,11 +18,17 @@
 
 namespace mindspore {
 namespace kernel {
-void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); }
+void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  node_ = kernel_node;
+  x_data_type_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
+  type_size_ = GetTypeByte(TypeIdToType(x_data_type_));
+}
 
 bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                               const std::vector<kernel::AddressPtr> & /*workspace*/,
                               const std::vector<kernel::AddressPtr> &outputs) {
+  auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
   if (inputs.empty() || outputs.empty()) {
     MS_LOG(EXCEPTION) << "input or output empty!";
   }
@@ -34,7 +40,10 @@ bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
     return true;
   }
 
-  size_t mem_bits = outputs[0]->size;
+  size_t mem_bits = type_size_;
+  for (size_t i = 0; i < x_shape.size(); ++i) {
+    mem_bits *= x_shape[i];
+  }
   auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h
index e7a875b1c0..162c40d2c3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h
@@ -31,6 +31,11 @@ class ReshapeCPUKernel : public CPUKernel {
 
   bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
               const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  CNodePtr node_ = nullptr;
+  TypeId x_data_type_{kNumberTypeInt32};
+  size_t type_size_ = 4;
 };
 
 MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index c08f59fad2..c92b013156 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -436,54 +436,69 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
 
 AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
-  const std::string &op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 2);
-  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  auto reshape = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
-  auto input_shp = input->shape()->shape();
-  auto reshape_val = reshape->BuildValue();
-  if (reshape_val->isa<AnyValue>()) {
-    MS_LOG(EXCEPTION) << "Input_shape can't be anything: " << args_spec_list[1]->ToString();
-  }
-  auto reshape_val_data = reshape_val->cast<ValueTuplePtr>()->value();
-  ShapeVector reshape_vec;
-  (void)std::transform(std::begin(reshape_val_data), std::end(reshape_val_data), std::back_inserter(reshape_vec),
-                       [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
-  ShapeVector result_shp;
-  auto input_prod = input_shp[0];
-  int64_t dim_prod = 1;
-  size_t neg_idx = 0;
-  for (size_t i = 1; i < input_shp.size(); i++) {
-    input_prod *= input_shp[i];
+  const std::string op_name = primitive->name();
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  ShapeVector shape;
+  ShapeVector x_shape = x->shape()->shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  if (x_max_shape.empty()) {
+    x_max_shape = x_shape;
   }
-  auto num_neg_one = count(std::begin(reshape_vec), std::end(reshape_vec), -1);
-  if (num_neg_one > 1) {
-    MS_LOG(EXCEPTION) << "The shape can only has one -1 at most, but " << num_neg_one;
+  if (x_min_shape.empty()) {
+    x_min_shape = x_shape;
   }
-  for (size_t i = 0; i < reshape_vec.size(); i++) {
-    if (reshape_vec[i] == -1) {
-      neg_idx = i;
-      result_shp.push_back(-1);
-    } else {
-      dim_prod *= reshape_vec[i];
-      result_shp.push_back(reshape_vec[i]);
-    }
+  ValuePtr sh = primitive->GetAttr("shape");
+  auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(reshape_value_tuple);
+  auto reshape_tuple = reshape_value_tuple->value();
+
+  (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
+                       [](const ValuePtr &e) -> int { return GetValue<int>(e); });
+
+  auto max_shape = shape;
+  auto min_shape = shape;
+  int x_num = 1;
+  int x_min_num = 1;
+  int x_max_num = 1;
+  for (int value : x_shape) {
+    x_num = IntMulWithOverflowCheck(value, x_num);
   }
-  if (dim_prod < 0 || input_prod % dim_prod != 0) {
-    MS_LOG(EXCEPTION) << "The input_x shape product is " << input_prod << ", input_shape shape product is " << dim_prod
-                      << ", and this value should be > 0 and should divide product of input_x.";
+  for (int value : x_min_shape) {
+    x_min_num = IntMulWithOverflowCheck(value, x_min_num);
   }
-  if (num_neg_one == 1) {
-    int64_t val = static_cast<int64_t>(input_prod) / dim_prod;
-    dim_prod *= val;
-    result_shp[neg_idx] = val;
+  for (int value : x_max_shape) {
+    x_max_num = IntMulWithOverflowCheck(value, x_max_num);
   }
-  if (dim_prod != input_prod) {
-    MS_LOG(EXCEPTION)
-      << "The product of input_x shape should be equal to product of input_shape shape, but input_x shape is "
-      << input_prod << ", product of input_shape shape is " << dim_prod;
+
+  auto it_first = find(shape.begin(), shape.end(), -1);
+  if (it_first != shape.end()) {
+    auto it_second = find(it_first + 1, shape.end(), -1);
+    if (it_second != shape.end()) {
+      MS_LOG(EXCEPTION) << "At most one component of input shape can be -1";
+    }
+    int index = std::distance(shape.begin(), it_first);
+    int infer_value = x_num;
+    int infer_min_value = x_min_num;
+    int infer_max_value = x_max_num;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      int value = shape[i];
+      if (value != -1 && value != 0) {
+        infer_value = infer_value / value;
+        infer_min_value = infer_min_value / value;
+        infer_max_value = infer_max_value / value;
+      }
+    }
+    shape[index] = infer_value;
+    min_shape[index] = infer_min_value;
+    max_shape[index] = infer_max_value;
   }
-  return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
+
+  AbstractTensorPtr ret =
+    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  return ret;
 }
 }  // namespace abstract
 }  // namespace mindspore
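Note: the -1 handling in the new InferImplReshape resolves the single wildcard dimension three times, against the real, minimum, and maximum element counts of x. A minimal Python sketch of that rule, for illustration only -- the helper name infer_wildcard_dim is invented and is not MindSpore API:

    from math import prod  # Python 3.8+

    def infer_wildcard_dim(target, x_shape, x_min_shape, x_max_shape):
        # Mirrors the C++ hunk above: at most one -1 is allowed, and it is
        # filled in by dividing each element count by the known target dims.
        if target.count(-1) > 1:
            raise ValueError("At most one component of input shape can be -1")
        shape, min_shape, max_shape = list(target), list(target), list(target)
        if -1 in target:
            idx = target.index(-1)
            known = prod(v for v in target if v not in (-1, 0))
            shape[idx] = prod(x_shape) // known
            min_shape[idx] = prod(x_min_shape) // known
            max_shape[idx] = prod(x_max_shape) // known
        return shape, min_shape, max_shape

    # infer_wildcard_dim([-1, 2], [8, 4], [1, 4], [16, 4])
    # -> ([16, 2], [2, 2], [32, 2])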
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 10753f4a0f..8f7aed7e7c 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -440,6 +440,14 @@ class Reshape(PrimitiveWithInfer):
         validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         validator.check_value_type("shape", shape_v, [tuple], self.name)
         shape_v = list(shape_v)
+        if 'max_shape' in x:
+            x_max_shape = x['max_shape']
+        else:
+            x_max_shape = x['shape']
+        if 'min_shape' in x:
+            x_min_shape = x['min_shape']
+        else:
+            x_min_shape = x['shape']
         neg_index = -1
         dim_prod = 1
         for i, shp_i in enumerate(shape_v):
@@ -451,14 +459,19 @@ class Reshape(PrimitiveWithInfer):
             else:
                 dim_prod *= shp_i
         arr_prod = np.prod(x_shp)
+        max_arr_prod = np.prod(x_max_shape)
+        min_arr_prod = np.prod(x_min_shape)
         if dim_prod <= 0 or arr_prod % dim_prod != 0:
             raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.'
                              f'The product of input_x\'s shape should > 0, '
                              f'and can be divided by product of input_shape, '
                              f'but product of input_x\'s shape is {arr_prod}, product of input_shape is {dim_prod}.')
-
+        max_shape = list(shape_v)
+        min_shape = list(shape_v)
         if neg_index != -1:
             shape_v[neg_index] = int(arr_prod / dim_prod)
+            max_shape[neg_index] = int(max_arr_prod / dim_prod)
+            min_shape[neg_index] = int(min_arr_prod / dim_prod)
             dim_prod *= shape_v[neg_index]
         if dim_prod != arr_prod:
             raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.'
@@ -471,7 +484,9 @@
 
         out = {'shape': tuple(shape_v),
                'dtype': x['dtype'],
-               'value': value}
+               'value': value,
+               'max_shape': tuple(max_shape),
+               'min_shape': tuple(min_shape)}
         return out
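Note: the Python side of the patch does the same bookkeeping in Reshape.__infer__. When the input dict carries min_shape/max_shape entries (a dynamic-shape input), the products of those bounds are divided by dim_prod to resolve the -1 entry, and the results travel with the output dict. A worked example of that arithmetic, with invented numbers (assume the first dimension of the input can range from 1 to 16):

    import numpy as np

    x_min_shape, x_max_shape = (1, 4), (16, 4)   # hypothetical dynamic input bounds
    shape_v = [-1, 2]                            # requested target shape
    neg_index, dim_prod = 0, 2                   # -1 position; product of known dims

    min_shape, max_shape = list(shape_v), list(shape_v)
    min_shape[neg_index] = int(np.prod(x_min_shape) / dim_prod)  # 4 / 2  -> 2
    max_shape[neg_index] = int(np.prod(x_max_shape) / dim_prod)  # 64 / 2 -> 32
    print(tuple(min_shape), tuple(max_shape))    # (2, 2) (32, 2)

So the inferred output carries 'min_shape': (2, 2) and 'max_shape': (32, 2) alongside 'shape', giving downstream passes bounds for the dynamic dimension.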