Merge pull request !501 from chujinjin/abstract_input_tensortags/v0.3.0-alpha
| @@ -30,7 +30,8 @@ | |||||
| #include "pipeline/parse/data_converter.h" | #include "pipeline/parse/data_converter.h" | ||||
| #include "pipeline/static_analysis/prim.h" | #include "pipeline/static_analysis/prim.h" | ||||
| #include "session/session_factory.h" | #include "session/session_factory.h" | ||||
| #include "pre_activate/pass/const_input_to_attr_registry.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "pynative/base.h" | #include "pynative/base.h" | ||||
| #ifdef ENABLE_GE | #ifdef ENABLE_GE | ||||
| @@ -188,6 +189,117 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||||
| return std::move(result); | return std::move(result); | ||||
| } | } | ||||
| bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, | |||||
| const std::unordered_set<size_t> &input_attrs) { | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| auto input_names_value = op_prim->GetAttr(kAttrInputNames); | |||||
| if (input_names_value == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value); | |||||
| if (input_index >= input_names_vec.size()) { | |||||
| MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!"; | |||||
| } | |||||
| if (input_attrs.find(input_index) != input_attrs.end()) { | |||||
| ValuePtr value = parse::data_converter::PyDataToValue(input_object); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| auto input_name = input_names_vec[input_index]; | |||||
| op_prim->set_attr(input_name, value); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim, | |||||
| std::vector<tensor::TensorPtr> *input_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||||
| for (const auto &input_object : tuple_inputs) { | |||||
| if (!py::isinstance<tensor::Tensor>(input_object)) { | |||||
| MS_LOG(EXCEPTION) << "The input object is not a tensor!"; | |||||
| } | |||||
| auto tensor = py::cast<tensor::TensorPtr>(input_object); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| input_tensor->push_back(tensor); | |||||
| } | |||||
| op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())})); | |||||
| } | |||||
| void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||||
| ValuePtr input_value = parse::data_converter::PyDataToValue(input_object); | |||||
| MS_EXCEPTION_IF_NULL(input_value); | |||||
| if (!input_value->isa<ValueTuple>()) { | |||||
| MS_LOG(EXCEPTION) << "The input object is not a value tuple!"; | |||||
| } | |||||
| auto value_tuple = input_value->cast<ValueTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_tuple); | |||||
| tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple); | |||||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||||
| input_tensor->push_back(tensor_ptr); | |||||
| } | |||||
| void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, | |||||
| std::vector<tensor::TensorPtr> *input_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||||
| tensor::TensorPtr tensor_ptr = nullptr; | |||||
| if (py::isinstance<tensor::Tensor>(input_object)) { | |||||
| tensor_ptr = py::cast<tensor::TensorPtr>(input_object); | |||||
| } else if (py::isinstance<py::float_>(input_object)) { | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32); | |||||
| } else if (py::isinstance<py::int_>(input_object)) { | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), nullptr); | |||||
| } else if (py::isinstance<py::list>(input_object)) { | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr); | |||||
| } else if (py::isinstance<py::array>(input_object)) { | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr); | |||||
| } else if (py::isinstance<py::tuple>(input_object)) { | |||||
| auto tuple_inputs = py::cast<py::tuple>(input_object); | |||||
| if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) { | |||||
| PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensor); | |||||
| } else { | |||||
| ConvertValueTupleToTensor(input_object, input_tensor); | |||||
| } | |||||
| return; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||||
| input_tensor->push_back(tensor_ptr); | |||||
| } | |||||
| void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<bool> *tensors_mask, | |||||
| std::vector<tensor::TensorPtr> *input_tensors) { | |||||
| MS_EXCEPTION_IF_NULL(tensors_mask); | |||||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||||
| PrimitivePtr op_prim = op_run_info->py_primitive; | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) { | |||||
| MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size " | |||||
| << op_run_info->inputs_mask.size(); | |||||
| } | |||||
| opt::ConstInputToAttrInfoRegister reg; | |||||
| bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); | |||||
| size_t input_num = op_run_info->op_inputs.size(); | |||||
| MS_LOG(INFO) << "py input size: " << input_num; | |||||
| for (size_t index = 0; index < input_num; ++index) { | |||||
| // convert const input to attr | |||||
| if (reg_exist && | |||||
| RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) { | |||||
| continue; | |||||
| } | |||||
| // convert const and tuple input to tensor | |||||
| ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors); | |||||
| // make tensors, weight : 1, data : 0 | |||||
| std::vector<bool> new_mask(input_tensors->size() - tensors_mask->size(), | |||||
| py::cast<bool>(op_run_info->inputs_mask[index])); | |||||
| tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); | |||||
| } | |||||
| } | |||||
| py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { | py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { | ||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | MS_EXCEPTION_IF_NULL(op_exec_info); | ||||
| MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; | MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; | ||||
| @@ -204,7 +316,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||||
| std::string graph_info = GetSingleOpGraphInfo(op_exec_info); | std::string graph_info = GetSingleOpGraphInfo(op_exec_info); | ||||
| std::vector<tensor::TensorPtr> input_tensors; | std::vector<tensor::TensorPtr> input_tensors; | ||||
| session->BuildOp(*op_exec_info, graph_info, &input_tensors); | |||||
| std::vector<bool> tensors_mask; | |||||
| ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); | |||||
| session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask); | |||||
| py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); | py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); | ||||
| ms_context->set_enable_pynative_infer(false); | ms_context->set_enable_pynative_infer(false); | ||||
| *status = PYNATIVE_SUCCESS; | *status = PYNATIVE_SUCCESS; | ||||
| @@ -250,11 +250,11 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra | |||||
| } | } | ||||
| void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| std::vector<tensor::TensorPtr> *input_tensors) { | |||||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||||
| const std::vector<bool> &tensors_mask) { | |||||
| MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; | MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; | ||||
| // construct graph include one op | // construct graph include one op | ||||
| auto graph = ConstructSingleOpGraph(op_run_info, input_tensors); | |||||
| auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| opt::RunOpAscendBackendIRFusionOptimization(graph); | opt::RunOpAscendBackendIRFusionOptimization(graph); | ||||
| // kernel select | // kernel select | ||||
| @@ -42,7 +42,7 @@ class AscendSession : public SessionBasic { | |||||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | ||||
| void BuildGraph(GraphId) override; | void BuildGraph(GraphId) override; | ||||
| void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| std::vector<tensor::TensorPtr> *input_tensors) override; | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<bool> &tensors_mask) override; | |||||
| py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors) override; | const std::vector<tensor::TensorPtr> &input_tensors) override; | ||||
| @@ -133,10 +133,9 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||||
| } | } | ||||
| void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| std::vector<tensor::TensorPtr> *input_tensors) { | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<bool> &tensors_mask) { | |||||
| // Prepare the graph | // Prepare the graph | ||||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||||
| auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors); | |||||
| auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| SelectKernel(kernel_graph); | SelectKernel(kernel_graph); | ||||
| StartKernelRT(); | StartKernelRT(); | ||||
| @@ -40,7 +40,7 @@ class GPUSession : public SessionBasic { | |||||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | ||||
| void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| std::vector<tensor::TensorPtr> *input_tensors) override; | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<bool> &tensors_mask) override; | |||||
| py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors) override; | const std::vector<tensor::TensorPtr> &input_tensors) override; | ||||
| @@ -180,115 +180,6 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph, | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, | |||||
| const std::unordered_set<size_t> &input_attrs) { | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| auto input_names_value = op_prim->GetAttr(kAttrInputNames); | |||||
| if (input_names_value == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value); | |||||
| if (input_index >= input_names_vec.size()) { | |||||
| MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!"; | |||||
| } | |||||
| if (input_attrs.find(input_index) != input_attrs.end()) { | |||||
| ValuePtr value = parse::data_converter::PyDataToValue(input_object); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| auto input_name = input_names_vec[input_index]; | |||||
| op_prim->set_attr(input_name, value); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim, | |||||
| std::vector<tensor::TensorPtr> *input_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||||
| for (const auto &input_object : tuple_inputs) { | |||||
| if (!py::isinstance<tensor::Tensor>(input_object)) { | |||||
| MS_LOG(EXCEPTION) << "The input object is not a tensor!"; | |||||
| } | |||||
| auto tensor = py::cast<tensor::TensorPtr>(input_object); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| input_tensor->push_back(tensor); | |||||
| } | |||||
| op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())})); | |||||
| } | |||||
| void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||||
| ValuePtr input_value = parse::data_converter::PyDataToValue(input_object); | |||||
| MS_EXCEPTION_IF_NULL(input_value); | |||||
| if (!input_value->isa<ValueTuple>()) { | |||||
| MS_LOG(EXCEPTION) << "The input object is not a value tuple!"; | |||||
| } | |||||
| auto value_tuple = input_value->cast<ValueTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_tuple); | |||||
| tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple); | |||||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||||
| input_tensor->push_back(tensor_ptr); | |||||
| } | |||||
| void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, | |||||
| std::vector<tensor::TensorPtr> *input_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||||
| tensor::TensorPtr tensor_ptr = nullptr; | |||||
| if (py::isinstance<tensor::Tensor>(input_object)) { | |||||
| tensor_ptr = py::cast<tensor::TensorPtr>(input_object); | |||||
| } else if (py::isinstance<py::float_>(input_object)) { | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32); | |||||
| } else if (py::isinstance<py::int_>(input_object)) { | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), nullptr); | |||||
| } else if (py::isinstance<py::list>(input_object)) { | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr); | |||||
| } else if (py::isinstance<py::array>(input_object)) { | |||||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr); | |||||
| } else if (py::isinstance<py::tuple>(input_object)) { | |||||
| auto tuple_inputs = py::cast<py::tuple>(input_object); | |||||
| if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) { | |||||
| PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensor); | |||||
| } else { | |||||
| ConvertValueTupleToTensor(input_object, input_tensor); | |||||
| } | |||||
| return; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||||
| input_tensor->push_back(tensor_ptr); | |||||
| } | |||||
| void ConvertInputPyobject(const OpRunInfo &op_run_info, const PrimitivePtr &op_prim, | |||||
| std::vector<tensor::TensorPtr> *input_tensors, std::vector<bool> *tensors_mask) { | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||||
| MS_EXCEPTION_IF_NULL(tensors_mask); | |||||
| if (op_run_info.op_inputs.size() != op_run_info.inputs_mask.size()) { | |||||
| MS_LOG(EXCEPTION) << "Op input size " << op_run_info.op_inputs.size() << " should be equal to op input mask size " | |||||
| << op_run_info.inputs_mask.size(); | |||||
| } | |||||
| opt::ConstInputToAttrInfoRegister reg; | |||||
| bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info.op_name, ®); | |||||
| size_t input_num = op_run_info.op_inputs.size(); | |||||
| MS_LOG(INFO) << "py input size: " << input_num; | |||||
| for (size_t index = 0; index < input_num; ++index) { | |||||
| // convert const input to attr | |||||
| if (reg_exist && | |||||
| RunOpConvertConstInputToAttr(op_run_info.op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) { | |||||
| continue; | |||||
| } | |||||
| // convert const and tuple input to tensor | |||||
| ConvertPyObjectToTensor(op_run_info.op_inputs[index], op_prim, input_tensors); | |||||
| // make tensors, weight : 1, data : 0 | |||||
| std::vector<bool> new_mask(input_tensors->size() - tensors_mask->size(), | |||||
| py::cast<bool>(op_run_info.inputs_mask[index])); | |||||
| tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); | |||||
| } | |||||
| } | |||||
| ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { | ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { | ||||
| auto value_node = anf->cast<ValueNodePtr>(); | auto value_node = anf->cast<ValueNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(value_node); | MS_EXCEPTION_IF_NULL(value_node); | ||||
| @@ -747,26 +638,22 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr | |||||
| } | } | ||||
| std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info, | std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info, | ||||
| std::vector<tensor::TensorPtr> *input_tensors) { | |||||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||||
| const std::vector<bool> &tensors_mask) { | |||||
| auto graph = std::make_shared<KernelGraph>(); | auto graph = std::make_shared<KernelGraph>(); | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| // set input[0] | // set input[0] | ||||
| PrimitivePtr op_prim = op_run_info.py_primitive; | PrimitivePtr op_prim = op_run_info.py_primitive; | ||||
| if (op_prim == nullptr) { | |||||
| op_prim = std::make_shared<Primitive>(op_run_info.op_name); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(op_prim); | |||||
| inputs.push_back(std::make_shared<ValueNode>(op_prim)); | inputs.push_back(std::make_shared<ValueNode>(op_prim)); | ||||
| // set input parameter | // set input parameter | ||||
| std::vector<bool> tensors_mask; | |||||
| ConvertInputPyobject(op_run_info, op_prim, input_tensors, &tensors_mask); | |||||
| MS_LOG(INFO) << "Input tensor size: " << input_tensors->size(); | |||||
| if (input_tensors->size() != tensors_mask.size()) { | |||||
| MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size " | |||||
| MS_LOG(INFO) << "Input tensor size: " << input_tensors.size(); | |||||
| if (input_tensors.size() != tensors_mask.size()) { | |||||
| MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size " | |||||
| << tensors_mask.size(); | << tensors_mask.size(); | ||||
| } | } | ||||
| for (size_t i = 0; i < input_tensors->size(); ++i) { | |||||
| auto parameter = ConstructRunOpParameter(graph, input_tensors->at(i), tensors_mask[i]); | |||||
| for (size_t i = 0; i < input_tensors.size(); ++i) { | |||||
| auto parameter = ConstructRunOpParameter(graph, input_tensors.at(i), tensors_mask[i]); | |||||
| inputs.push_back(parameter); | inputs.push_back(parameter); | ||||
| graph->MutableInputs()->push_back(parameter); | graph->MutableInputs()->push_back(parameter); | ||||
| } | } | ||||
| @@ -61,7 +61,8 @@ class SessionBasic { | |||||
| virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0; | virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0; | ||||
| virtual void BuildOp(const OpRunInfo &, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors) {} | |||||
| virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, | |||||
| const std::vector<bool> &tensors_mask) {} | |||||
| virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors) { | virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors) { | ||||
| return py::tuple(); | return py::tuple(); | ||||
| @@ -99,7 +100,8 @@ class SessionBasic { | |||||
| CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph); | CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph); | ||||
| // create a single run op graph | // create a single run op graph | ||||
| std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info, | std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info, | ||||
| std::vector<tensor::TensorPtr> *input_tensor); | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||||
| const std::vector<bool> &tensors_mask); | |||||
| // trans BaseRef list to py::tuple | // trans BaseRef list to py::tuple | ||||
| BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); | BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); | ||||