| @@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp | |||
| return false; | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| if (param_type == "dynamic") { | |||
| if (dyn_input_sizes.empty()) { | |||
| MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; | |||
| @@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| inputs_device_type.push_back(type_id); | |||
| inputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| } | |||
| dyn_input_idx++; | |||
| } else if (param_type == "required") { | |||
| @@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| inputs_device_type.push_back(type_id); | |||
| inputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| } else { | |||
| if (kernel_info_index < real_input_num) { | |||
| MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index; | |||
| @@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| inputs_device_type.push_back(type_id); | |||
| inputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| } | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| reshape_types.push_back(reshape_type); | |||
| } | |||
| builder->SetInputReshapeType(reshape_types); | |||
| @@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou | |||
| MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!"; | |||
| continue; | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| size_t output_num = 0; | |||
| if (output->param_type() == "dynamic") { | |||
| if (outputs.size() > 1) { | |||
| @@ -467,13 +475,9 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| outputs_device_type.push_back(type_id); | |||
| outputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| output_idx++; | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| reshape_types.push_back(reshape_type); | |||
| } | |||
| builder->SetOutputReshapeType(reshape_types); | |||