|
|
|
@@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<Axis> reshape_type; |
|
|
|
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (param_type == "dynamic") { |
|
|
|
if (dyn_input_sizes.empty()) { |
|
|
|
MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; |
|
|
|
@@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp |
|
|
|
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); |
|
|
|
inputs_device_type.push_back(type_id); |
|
|
|
inputs_format.push_back(formats[builder_idex]); |
|
|
|
reshape_types.push_back(reshape_type); |
|
|
|
} |
|
|
|
dyn_input_idx++; |
|
|
|
} else if (param_type == "required") { |
|
|
|
@@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp |
|
|
|
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); |
|
|
|
inputs_device_type.push_back(type_id); |
|
|
|
inputs_format.push_back(formats[builder_idex]); |
|
|
|
reshape_types.push_back(reshape_type); |
|
|
|
} else { |
|
|
|
if (kernel_info_index < real_input_num) { |
|
|
|
MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index; |
|
|
|
@@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp |
|
|
|
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); |
|
|
|
inputs_device_type.push_back(type_id); |
|
|
|
inputs_format.push_back(formats[builder_idex]); |
|
|
|
reshape_types.push_back(reshape_type); |
|
|
|
} |
|
|
|
} |
|
|
|
std::vector<Axis> reshape_type; |
|
|
|
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
reshape_types.push_back(reshape_type); |
|
|
|
} |
|
|
|
|
|
|
|
builder->SetInputReshapeType(reshape_types); |
|
|
|
@@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou |
|
|
|
MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::vector<Axis> reshape_type; |
|
|
|
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
size_t output_num = 0; |
|
|
|
if (output->param_type() == "dynamic") { |
|
|
|
if (outputs.size() > 1) { |
|
|
|
@@ -467,13 +475,9 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou |
|
|
|
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); |
|
|
|
outputs_device_type.push_back(type_id); |
|
|
|
outputs_format.push_back(formats[builder_idex]); |
|
|
|
reshape_types.push_back(reshape_type); |
|
|
|
output_idx++; |
|
|
|
} |
|
|
|
std::vector<Axis> reshape_type; |
|
|
|
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
reshape_types.push_back(reshape_type); |
|
|
|
} |
|
|
|
|
|
|
|
builder->SetOutputReshapeType(reshape_types); |
|
|
|
|