Browse Source

!8582 add dynamic shape support for the Reshape operator

From: @fangzehua
Reviewed-by: @stsuteng
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
0ab808ec9e
4 changed files with 94 additions and 50 deletions
  1. +11
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc
  2. +5
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h
  3. +61
    -46
      mindspore/core/abstract/prim_arrays.cc
  4. +17
    -2
      mindspore/ops/operations/array_ops.py

+ 11
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc View File

@@ -18,11 +18,17 @@

namespace mindspore {
namespace kernel {
void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); }
void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
node_ = kernel_node;
x_data_type_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
type_size_ = GetTypeByte(TypeIdToType(x_data_type_));
}

bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
if (inputs.empty() || outputs.empty()) {
MS_LOG(EXCEPTION) << "input or output empty!";
}
@@ -34,7 +40,10 @@ bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
return true;
}

size_t mem_bits = outputs[0]->size;
size_t mem_bits = type_size_;
for (size_t i = 0; i < x_shape.size(); ++i) {
mem_bits *= x_shape[i];
}
auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;


+ 5
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h View File

@@ -31,6 +31,11 @@ class ReshapeCPUKernel : public CPUKernel {

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

private:
CNodePtr node_ = nullptr;
TypeId x_data_type_{kNumberTypeInt32};
size_t type_size_ = 4;
};

MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),


+ 61
- 46
mindspore/core/abstract/prim_arrays.cc View File

@@ -566,54 +566,69 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr

// NOTE(review): this span is a diff rendering of InferImplReshape in which the
// removed (pre-change) and added (post-change) lines are interleaved with no
// +/- markers. It is NOT compilable as-is: two implementations are spliced
// together. The section markers below separate the two versions.
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// ---- OLD implementation (removed by this commit): reads the target shape from
// ---- the second argument (an AbstractTuple value) and infers a static shape.
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto reshape = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
auto input_shp = input->shape()->shape();
auto reshape_val = reshape->BuildValue();
// The target shape must be a concrete compile-time value.
if (reshape_val->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "Input_shape can't be anything: " << args_spec_list[1]->ToString();
}
auto reshape_val_data = reshape_val->cast<ValueTuplePtr>()->value();
ShapeVector reshape_vec;
(void)std::transform(std::begin(reshape_val_data), std::end(reshape_val_data), std::back_inserter(reshape_vec),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
ShapeVector result_shp;
// input_prod accumulates the total element count of the input tensor.
auto input_prod = input_shp[0];
int64_t dim_prod = 1;
size_t neg_idx = 0;
for (size_t i = 1; i < input_shp.size(); i++) {
input_prod *= input_shp[i];
}
// At most one wildcard (-1) dimension is allowed in the target shape.
auto num_neg_one = count(std::begin(reshape_vec), std::end(reshape_vec), -1);
if (num_neg_one > 1) {
MS_LOG(EXCEPTION) << "The shape can only has one -1 at most, but " << num_neg_one;
}
// dim_prod accumulates the product of the known (non -1) target dimensions.
for (size_t i = 0; i < reshape_vec.size(); i++) {
if (reshape_vec[i] == -1) {
neg_idx = i;
result_shp.push_back(-1);
} else {
dim_prod *= reshape_vec[i];
result_shp.push_back(reshape_vec[i]);
// ---- NEW implementation (added by this commit): reads the target shape from
// ---- the "shape" primitive attribute and also propagates min/max shapes so
// ---- the -1 wildcard can be inferred for dynamic-shape inputs.
const std::string op_name = primitive->name();
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
ShapeVector shape;
ShapeVector x_shape = x->shape()->shape();
ShapeVector x_max_shape = x->shape()->max_shape();
ShapeVector x_min_shape = x->shape()->min_shape();
// Static-shape inputs carry empty min/max; fall back to the real shape.
if (x_max_shape.empty()) {
x_max_shape = x_shape;
}
if (x_min_shape.empty()) {
x_min_shape = x_shape;
}
ValuePtr sh = primitive->GetAttr("shape");
auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(reshape_value_tuple);
auto reshape_tuple = reshape_value_tuple->value();

(void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
[](const ValuePtr &e) -> int { return GetValue<int>(e); });

auto max_shape = shape;
auto min_shape = shape;
// Element counts of the real / min / max input shapes, overflow-checked.
int x_num = 1;
int x_min_num = 1;
int x_max_num = 1;
for (int value : x_shape) {
x_num = IntMulWithOverflowCheck(value, x_num);
}
for (int value : x_min_shape) {
x_min_num = IntMulWithOverflowCheck(value, x_min_num);
}
for (int value : x_max_shape) {
x_max_num = IntMulWithOverflowCheck(value, x_max_num);
}

auto it_first = find(shape.begin(), shape.end(), -1);
if (it_first != shape.end()) {
auto it_second = find(it_first + 1, shape.end(), -1);
if (it_second != shape.end()) {
MS_LOG(EXCEPTION) << "At most one component of input shape can be -1";
}
// NOTE(review): the arguments look swapped — std::distance(first, last) here
// has it_first before shape.begin(), which yields a non-positive index.
// Presumably std::distance(shape.begin(), it_first) was intended; verify.
int index = std::distance(it_first, shape.begin());
// Infer the -1 dimension by dividing the element count by all known dims
// (0-valued dims are skipped to avoid division by zero).
int infer_value = x_num;
int infer_min_value = x_min_num;
int infer_max_value = x_max_num;
for (size_t i = 0; i < shape.size(); ++i) {
int value = shape[i];
if (value != -1 && value != 0) {
infer_value = infer_value / value;
infer_min_value = infer_min_value / value;
infer_max_value = infer_max_value / value;
}
}
shape[index] = infer_value;
min_shape[index] = infer_min_value;
max_shape[index] = infer_max_value;
}
// ---- OLD implementation, continued (removed): validate that the product of
// ---- the target shape divides/matches the input element count, fill in -1.
if (dim_prod < 0 || input_prod % dim_prod != 0) {
MS_LOG(EXCEPTION) << "The input_x shape product is " << input_prod << ", input_shape shape product is " << dim_prod
<< ", and this value should be > 0 and should divide product of input_x.";
}
if (num_neg_one == 1) {
int64_t val = static_cast<int64_t>(input_prod) / dim_prod;
dim_prod *= val;
result_shp[neg_idx] = val;
}
if (dim_prod != input_prod) {
MS_LOG(EXCEPTION)
<< "The product of input_x shape should be equal to product of input_shape shape, but input_x shape is "
<< input_prod << ", product of input_shape shape is " << dim_prod;
}
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));

// ---- NEW implementation, continued (added): return the abstract tensor
// ---- carrying min/max shapes for dynamic-shape propagation.
AbstractTensorPtr ret =
std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
return ret;
}
} // namespace abstract
} // namespace mindspore

+ 17
- 2
mindspore/ops/operations/array_ops.py View File

@@ -445,6 +445,14 @@ class Reshape(PrimitiveWithInfer):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
validator.check_value_type("shape", shape_v, [tuple], self.name)
shape_v = list(shape_v)
if 'max_shape' in x:
x_max_shape = x['max_shape']
else:
x_max_shape = x['shape']
if 'min_shape' in x:
x_min_shape = x['min_shape']
else:
x_min_shape = x['shape']
neg_index = -1
dim_prod = 1
for i, shp_i in enumerate(shape_v):
@@ -456,14 +464,19 @@ class Reshape(PrimitiveWithInfer):
else:
dim_prod *= shp_i
arr_prod = np.prod(x_shp)
max_arr_prod = np.prod(x_max_shape)
min_arr_prod = np.prod(x_min_shape)
if dim_prod <= 0 or arr_prod % dim_prod != 0:
raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.'
f'The product of input_x\'s shape should > 0, '
f'and can be divided by product of input_shape, '
f'but product of input_x\'s shape is {arr_prod}, product of input_shape is {dim_prod}.')

max_shape = list(shape_v)
min_shape = list(shape_v)
if neg_index != -1:
shape_v[neg_index] = int(arr_prod / dim_prod)
max_shape[neg_index] = int(max_arr_prod / dim_prod)
min_shape[neg_index] = int(min_arr_prod / dim_prod)
dim_prod *= shape_v[neg_index]
if dim_prod != arr_prod:
raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.'
@@ -476,7 +489,9 @@ class Reshape(PrimitiveWithInfer):

out = {'shape': tuple(shape_v),
'dtype': x['dtype'],
'value': value}
'value': value,
'max_shape': tuple(max_shape),
'min_shape': tuple(min_shape)}
return out




Loading…
Cancel
Save