|
|
@@ -393,9 +393,7 @@ bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_i |
|
|
ValuePtr value = parse::data_converter::PyDataToValue(input_object); |
|
|
ValuePtr value = parse::data_converter::PyDataToValue(input_object); |
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
auto input_name = input_names_vec[input_index]; |
|
|
auto input_name = input_names_vec[input_index]; |
|
|
op_prim->BeginRecordAddAttr(); |
|
|
|
|
|
op_prim->AddAttr(input_name, value); |
|
|
op_prim->AddAttr(input_name, value); |
|
|
op_prim->EndRecordAddAttr(); |
|
|
|
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
return false; |
|
|
return false; |
|
|
@@ -499,6 +497,8 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te |
|
|
|
|
|
|
|
|
opt::ConstInputToAttrInfoRegister reg; |
|
|
opt::ConstInputToAttrInfoRegister reg; |
|
|
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); |
|
|
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); |
|
|
|
|
|
|
|
|
|
|
|
op_prim->BeginRecordAddAttr(); |
|
|
size_t input_num = op_run_info->op_inputs.size(); |
|
|
size_t input_num = op_run_info->op_inputs.size(); |
|
|
for (size_t index = 0; index < input_num; ++index) { |
|
|
for (size_t index = 0; index < input_num; ++index) { |
|
|
// convert const input to attr |
|
|
// convert const input to attr |
|
|
@@ -513,6 +513,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te |
|
|
std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); |
|
|
std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); |
|
|
tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); |
|
|
tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); |
|
|
} |
|
|
} |
|
|
|
|
|
op_prim->EndRecordAddAttr(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) { |
|
|
void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) { |
|
|
|