@@ -28,6 +28,7 @@
 namespace mindspore {
 namespace opt {
+constexpr size_t kType32Len = 4;
 std::vector<int> Convert2Int(const std::vector<size_t> &v) {
   std::vector<int> result;
   (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
@@ -264,6 +265,62 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod
   }
 }
+template <typename T>
+tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
+                                             size_t data_length) {
+  MS_EXCEPTION_IF_NULL(value_tuple_ptr);
+  MS_EXCEPTION_IF_NULL(type_ptr);
+  std::vector<T> values;
+  for (const auto &v : value_tuple_ptr->value()) {
+    MS_EXCEPTION_IF_NULL(v);
+    if (v->isa<Scalar>()) {
+      ScalarPtr scalar = v->cast<ScalarPtr>();
+      values.push_back(GetValue<T>(scalar));
+    } else {
+      MS_LOG(WARNING) << "The value " << v << " of tuple is not a scalar";
+      return nullptr;
+    }
+  }
+  std::vector<int> tensor_shape = {SizeToInt(values.size())};
+  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
+  MS_EXCEPTION_IF_NULL(tensor);
+  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
+  tensor->set_device_info(device_info);
+  auto data_ptr = tensor->data_c(true);
+  MS_EXCEPTION_IF_NULL(data_ptr);
+  auto elem_num = values.size() * data_length;
+  auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
+  if (ret_code != 0) {
+    MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
+  }
+  return tensor;
+}
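A quick illustrative sketch of the helper added above (editor's example, not part of the patch; the ValueTuple construction via MakeValue is an assumption about the surrounding IR API):

std::vector<ValuePtr> elems = {MakeValue(2), MakeValue(3), MakeValue(4)};
auto tuple = std::make_shared<ValueTuple>(elems);
// One int32 element per scalar; kType32Len (4 bytes) drives the memcpy_s size.
tensor::TensorPtr t = CreateTensorWithValueTuple<int>(tuple, kInt32, kType32Len);
// t is a 1-D tensor of shape {3} holding {2, 3, 4}.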
+tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
+  MS_EXCEPTION_IF_NULL(value_tuple);
+  tensor::TensorPtr tensor = nullptr;
+  ValuePtr v = *(value_tuple->value().begin());
+  MS_EXCEPTION_IF_NULL(v);
+  // Currently we only deal with the scalar tuple
+  if (!v->isa<Scalar>()) {
+    MS_LOG(WARNING) << "The value " << v << " of tuple is not a scalar";
+    return nullptr;
+  }
+  ScalarPtr scalar = v->cast<ScalarPtr>();
+  MS_EXCEPTION_IF_NULL(scalar);
+  if (scalar->isa<IntergerImm>()) {
+    tensor = CreateTensorWithValueTuple<int>(value_tuple, kInt32, kType32Len);
+  } else if (scalar->isa<FloatImm>()) {
+    tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, kType32Len);
+  } else {
+    auto type = scalar->type();
+    auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
+    MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
+    return nullptr;
+  }
+  return tensor;
+}
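Note that CreateTupleTensor dispatches on the first element only, so a homogeneous tuple is assumed. A sketch with a hypothetical float tuple (editor's example):

// float_tuple: hypothetical ValueTuplePtr holding FloatImm scalars (2.0, 3.0).
tensor::TensorPtr t = CreateTupleTensor(float_tuple);
// t is a kFloat32 tensor of shape {2}; a mixed int/float tuple would be rejected
// inside the templated helper when GetValue<float> hits a non-float scalar.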
 bool IsNopNode(const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -135,6 +135,11 @@ void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_i
 void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num,
                                     std::vector<AnfNodePtr> *outputs);
+tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
+                                             size_t data_length);
+tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple);
 bool IsNopNode(const AnfNodePtr &node);
 void HideNopNode(session::KernelGraph *const graph);
@@ -17,10 +17,44 @@
 #include <utility>
 #include "utils/utils.h"
 #include "utils/log_adapter.h"
+#include "operator/ops.h"
 namespace mindspore {
 namespace opt {
+ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
+  Register(prim::kPrimCast->name(), {1});
+  Register(prim::kPrimConv2DBackpropInput->name(), {2});
+  Register(prim::kPrimConv2DBackpropFilter->name(), {2});
+  Register(prim::kPrimReshape->name(), {1});
+  Register(prim::kPrimReduceMax->name(), {1});
+  Register(prim::kPrimReduceMin->name(), {1});
+  Register(prim::kPrimReduceSum->name(), {1});
+  Register(prim::kPrimReduceMean->name(), {1});
+  Register(prim::kPrimGatherV2->name(), {2});
+  Register(prim::kPrimTranspose->name(), {1});
+  Register(prim::kPrimUnsortedSegmentSum->name(), {2});
+  Register(prim::kPrimOneHot->name(), {1});
+  Register(kUnsortedSegmentProdOpName, {2});
+  Register(kUnsortedSegmentMinOpName, {2});
+  Register(kSimpleMeanGradOpName, {1});
+  Register(kMeanGradOpName, {1});
+  Register(kSliceOpName, {1, 2});
+  Register(kSliceGradOpName, {2, 3});
+  Register(kTileOpName, {1});
+  Register(kScatterNdOpName, {2});
+  Register(kStridedSliceAssignOpName, {1, 2, 3});
+  Register(kStridedSliceOpName, {1, 2, 3});
+  Register(kStridedSliceGradOpName, {1, 2, 3, 4});
+  Register(kFlattenGradOpName, {1});
+  Register(kExpandDimsOpName, {1});
+  Register(kSplitOpName, {0});
+  Register(kTopKOpName, {1});
+  Register(kSparseApplyAdagradOpName, {2});
+  Register(kResizeNearestNeighborGrad, {1});
+}
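Each Register call above maps an op name to the input indexes that should be folded into attributes rather than fed as graph inputs. A lookup sketch (editor's example; only names visible in this patch are used):

opt::ConstInputToAttrInfoRegister reg;
if (opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName("Cast", &reg)) {
  // For Cast this yields {1}: input 1 (the target dtype) becomes an attribute.
  const auto &attr_indexes = reg.GetConstInputAttrInfo();
}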
 ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() {
   static ConstInputToAttrInfoRegistry instance;
   return instance;
@@ -54,7 +54,7 @@ class ConstInputToAttrInfoRegistry {
   bool GetRegisterByOpName(const std::string &op_name, ConstInputToAttrInfoRegister *reg) const;
  private:
-  ConstInputToAttrInfoRegistry() = default;
+  ConstInputToAttrInfoRegistry();
   ~ConstInputToAttrInfoRegistry() = default;
   DISABLE_COPY_AND_ASSIGN(ConstInputToAttrInfoRegistry)
   std::unordered_map<std::string, ConstInputToAttrInfoRegister> op_input_to_attr_map_;
@@ -87,37 +87,5 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
   ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
   return cnode;
 }
-void ConvertConstInputToAttr::Init() {
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimCast->name(), {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimConv2DBackpropInput->name(), {2});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimConv2DBackpropFilter->name(), {2});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReshape->name(), {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReduceMax->name(), {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReduceMin->name(), {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReduceSum->name(), {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReduceMean->name(), {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimGatherV2->name(), {2});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimTranspose->name(), {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimUnsortedSegmentSum->name(), {2});
-  ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimOneHot->name(), {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(kUnsortedSegmentProdOpName, {2});
-  ConstInputToAttrInfoRegistry::Instance().Register(kUnsortedSegmentMinOpName, {2});
-  ConstInputToAttrInfoRegistry::Instance().Register(kSimpleMeanGradOpName, {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(kMeanGradOpName, {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(kSliceOpName, {1, 2});
-  ConstInputToAttrInfoRegistry::Instance().Register(kSliceGradOpName, {2, 3});
-  ConstInputToAttrInfoRegistry::Instance().Register(kTileOpName, {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(kScatterNdOpName, {2});
-  ConstInputToAttrInfoRegistry::Instance().Register(kStridedSliceAssignOpName, {1, 2, 3});
-  ConstInputToAttrInfoRegistry::Instance().Register(kStridedSliceOpName, {1, 2, 3});
-  ConstInputToAttrInfoRegistry::Instance().Register(kStridedSliceGradOpName, {1, 2, 3, 4});
-  ConstInputToAttrInfoRegistry::Instance().Register(kFlattenGradOpName, {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(kExpandDimsOpName, {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(kSplitOpName, {0});
-  ConstInputToAttrInfoRegistry::Instance().Register(kTopKOpName, {1});
-  ConstInputToAttrInfoRegistry::Instance().Register(kSparseApplyAdagradOpName, {2});
-  ConstInputToAttrInfoRegistry::Instance().Register(kResizeNearestNeighborGrad, {1});
-}
 }  // namespace opt
 }  // namespace mindspore
@@ -27,14 +27,11 @@ namespace opt {
 class ConvertConstInputToAttr : public PatternProcessPass {
  public:
   explicit ConvertConstInputToAttr(bool multigraph = true)
-      : PatternProcessPass("convert_const_input_to_attr", multigraph) {
-    Init();
-  }
+      : PatternProcessPass("convert_const_input_to_attr", multigraph) {}
   ~ConvertConstInputToAttr() override = default;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
  private:
-  void Init();
   std::unordered_map<std::string, std::unordered_set<size_t>> op_input_attr_map_;
 };
 }  // namespace opt
@@ -19,69 +19,13 @@
 #include <memory>
 #include "utils/graph_utils.h"
 #include "pre_activate/common/helper.h"
 #include "session/anf_runtime_algorithm.h"
 #include "session/kernel_graph.h"
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t kType32Len = 4;
-template <typename T>
-tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
-                                             size_t data_length) {
-  MS_EXCEPTION_IF_NULL(value_tuple_ptr);
-  MS_EXCEPTION_IF_NULL(type_ptr);
-  std::vector<T> values;
-  for (const auto &v : value_tuple_ptr->value()) {
-    MS_EXCEPTION_IF_NULL(v);
-    if (v->isa<Scalar>()) {
-      ScalarPtr scalar = v->cast<ScalarPtr>();
-      values.push_back(GetValue<T>(scalar));
-    } else {
-      MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
-      return nullptr;
-    }
-  }
-  std::vector<int> tensor_shape = {SizeToInt(values.size())};
-  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
-  MS_EXCEPTION_IF_NULL(tensor);
-  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr};
-  tensor->set_device_info(device_info);
-  auto data_ptr = tensor->data_c(true);
-  MS_EXCEPTION_IF_NULL(data_ptr);
-  auto elem_num = values.size() * data_length;
-  auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num);
-  if (ret_code != 0) {
-    MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
-  }
-  return tensor;
-}
-tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
-  MS_EXCEPTION_IF_NULL(value_tuple);
-  tensor::TensorPtr tensor = nullptr;
-  ValuePtr v = *(value_tuple->value().begin());
-  MS_EXCEPTION_IF_NULL(v);
-  // Currently we only deal with the scalar tuple
-  if (!v->isa<Scalar>()) {
-    MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
-    return nullptr;
-  }
-  ScalarPtr scalar = v->cast<ScalarPtr>();
-  MS_EXCEPTION_IF_NULL(scalar);
-  if (scalar->isa<IntergerImm>()) {
-    tensor = CreateTensorWithValueTuple<int>(value_tuple, kInt32, kType32Len);
-  } else if (scalar->isa<FloatImm>()) {
-    tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, kType32Len);
-  } else {
-    auto type = scalar->type();
-    auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
-    MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
-    return nullptr;
-  }
-  return tensor;
-}
 AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
   MS_EXCEPTION_IF_NULL(input_node);
   auto value_node = input_node->cast<ValueNodePtr>();
@@ -158,8 +158,9 @@ py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat
   session->Init(ms_context->device_id());
   std::string graph_info = GetSingleOpGraphInfo(op_exec_info);
-  session->BuildOp(*op_exec_info, graph_info);
-  py::tuple result = session->RunOp(*op_exec_info, graph_info);
+  std::vector<tensor::TensorPtr> input_tensors;
+  session->BuildOp(*op_exec_info, graph_info, &input_tensors);
+  py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
   ms_context->set_enable_pynative_infer(false);
   *status = PYNATIVE_SUCCESS;
   return result;
@@ -204,10 +204,12 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
   MS_LOG(INFO) << "Finish!";
 }
-void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) {
+void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                            std::vector<tensor::TensorPtr> *input_tensors) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
   MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
   // construct graph include one op
-  auto graph = ConstructSingleOpGraph(op_run_info);
+  auto graph = ConstructSingleOpGraph(op_run_info, input_tensors);
   MS_EXCEPTION_IF_NULL(graph);
   opt::RunOpAscendBackendIRFusionOptimization(graph);
   // kernel select
@@ -222,14 +224,12 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph
   run_op_graphs_[graph_info] = graph;
 }
-py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) {
+py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                               const std::vector<tensor::TensorPtr> &input_tensors) {
   auto graph = run_op_graphs_[graph_info];
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
   // malloc mem
-  std::vector<tensor::TensorPtr> input_tensors = {};
-  std::vector<bool> tensors_mask = {};
-  ToTensorPtr(op_run_info, &input_tensors, &tensors_mask);
   RunOpMemoryAlloc(input_tensors, graph.get());
   // load input data to device
   LoadInputData(graph, input_tensors);
@@ -41,8 +41,10 @@ class AscendSession : public SessionBasic {
   GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
   void BuildGraph(GraphId) override;
-  void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) override;
-  py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) override;
+  void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+               std::vector<tensor::TensorPtr> *input_tensors) override;
+  py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                  const std::vector<tensor::TensorPtr> &input_tensors) override;
   // set parameters of final graph
   GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &args) override;
@@ -132,9 +132,11 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
   }
 }
-void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) {
+void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                         std::vector<tensor::TensorPtr> *input_tensors) {
   // Prepare the graph
-  auto kernel_graph = ConstructSingleOpGraph(op_run_info);
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors);
   MS_EXCEPTION_IF_NULL(kernel_graph);
   SelectKernel(kernel_graph);
   StartKernelRT();
@@ -142,12 +144,10 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in
   run_op_graphs_[graph_info] = kernel_graph;
 }
-py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) {
+py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                            const std::vector<tensor::TensorPtr> &input_tensors) {
   auto kernel_graph = run_op_graphs_[graph_info];
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  std::vector<tensor::TensorPtr> input_tensors = {};
-  std::vector<bool> tensors_mask = {};
-  ToTensorPtr(op_run_info, &input_tensors, &tensors_mask);
   RunOpAllocateMemory(input_tensors, kernel_graph.get());
   // Execute the computation
   LoadInputData(kernel_graph, input_tensors);
@@ -39,8 +39,10 @@ class GPUSession : public SessionBasic {
   GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
-  void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) override;
-  py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) override;
+  void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+               std::vector<tensor::TensorPtr> *input_tensors) override;
+  py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                  const std::vector<tensor::TensorPtr> &input_tensors) override;
  private:
   void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
@@ -17,6 +17,7 @@
 #include <utility>
 #include <algorithm>
 #include <unordered_map>
+#include <unordered_set>
 #include "pipeline/parse/data_converter.h"
 #include "ir/manager.h"
 #include "operator/ops.h"
@@ -26,6 +27,7 @@
 #include "session/anf_runtime_algorithm.h"
 #include "kernel/oplib/oplib.h"
 #include "pre_activate/common/common_backend_optimization.h"
+#include "pre_activate/pass/const_input_to_attr_registry.h"
 #include "pre_activate/common/helper.h"
 #include "common/utils.h"
 #include "ir/dtype.h"
@@ -178,56 +180,113 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
   return ret;
 }
-std::string FindOpInputParameterType(const std::string &op_name, kernel::OpImplyType implyType, size_t index) {
-  std::string para_type;
-  auto op_info = kernel::OpLib::FindOp(op_name, implyType);
-  if (op_info == nullptr) {
-    return para_type;
+bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
+                                  const std::unordered_set<size_t> &input_attrs) {
+  MS_EXCEPTION_IF_NULL(op_prim);
+  auto input_names_value = op_prim->GetAttr(kAttrInputNames);
+  if (input_names_value == nullptr) {
+    return false;
   }
+  auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
+  if (input_index >= input_names_vec.size()) {
+    MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
+  }
-  auto op_inputs_info_vec = op_info->inputs_ptr();
-  if (index >= op_inputs_info_vec.size()) {
-    return para_type;
+  if (input_attrs.find(input_index) != input_attrs.end()) {
+    ValuePtr value = parse::data_converter::PyDataToValue(input_object);
+    MS_EXCEPTION_IF_NULL(value);
+    auto input_name = input_names_vec[input_index];
+    op_prim->set_attr(input_name, value);
+    return true;
  }
-  auto op_io_info = op_inputs_info_vec[index];
-  MS_EXCEPTION_IF_NULL(op_io_info);
-  para_type = op_io_info->param_type();
-  return para_type;
+  return false;
 }
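A sketch of the new helper in action (editor's example; the concrete ReduceSum input names are an assumption based on the Register() calls above):

// ReduceSum is registered with {1}; its kAttrInputNames is assumed to be
// ("input_x", "axis"). Folding a constant axis object:
std::unordered_set<size_t> attr_indexes = {1};
bool folded = RunOpConvertConstInputToAttr(axis_obj /* hypothetical py::object, e.g. (0, 1) */, 1,
                                           op_prim, attr_indexes);
// folded == true: op_prim now carries attr "axis" and no tensor input is built
// for index 1; false means the index is not registered or kAttrInputNames is absent.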
-void RunOpConvertConstInputToAttr(const OpRunInfo &op_run_info, const std::shared_ptr<CNode> &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto op_inputs = op_run_info.op_inputs;
-  // get input names vector from attrs
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto input_names_value = primitive->GetAttr(kAttrInputNames);
-  if (input_names_value == nullptr) {
+void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
+                              std::vector<tensor::TensorPtr> *input_tensor) {
+  MS_EXCEPTION_IF_NULL(op_prim);
+  MS_EXCEPTION_IF_NULL(input_tensor);
+  for (const auto &input_object : tuple_inputs) {
+    if (!py::isinstance<tensor::Tensor>(input_object)) {
+      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
+    }
+    auto tensor = py::cast<tensor::TensorPtr>(input_object);
+    MS_EXCEPTION_IF_NULL(tensor);
+    input_tensor->push_back(tensor);
+  }
+  op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
+}
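Sketch of the dynamic-input case this helper serves (editor's example; AddN as the dynamic-input op is an assumption):

// A py::tuple of three tensors passed to a dynamic-input op such as AddN:
PlantTensorTupleToVector(tuple_of_3_tensors /* hypothetical py::tuple */, addn_prim, &input_tensors);
// input_tensors grew by 3, and addn_prim now carries kAttrDynInputSizes = {3},
// recording how many flattened inputs form the original tuple argument.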
+void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensor) {
+  MS_EXCEPTION_IF_NULL(input_tensor);
+  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
+  MS_EXCEPTION_IF_NULL(input_value);
+  if (!input_value->isa<ValueTuple>()) {
+    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
+  }
+  auto value_tuple = input_value->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(value_tuple);
+  tensor::TensorPtr tensor_ptr = nullptr;
+  tensor_ptr = opt::CreateTupleTensor(value_tuple);
+  MS_EXCEPTION_IF_NULL(tensor_ptr);
+  input_tensor->push_back(tensor_ptr);
+}
+void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
+                             std::vector<tensor::TensorPtr> *input_tensor) {
+  MS_EXCEPTION_IF_NULL(op_prim);
+  MS_EXCEPTION_IF_NULL(input_tensor);
+  tensor::TensorPtr tensor_ptr = nullptr;
+  if (py::isinstance<tensor::Tensor>(input_object)) {
+    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
+  } else if (py::isinstance<py::float_>(input_object)) {
+    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32);
+  } else if (py::isinstance<py::int_>(input_object)) {
+    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), nullptr);
+  } else if (py::isinstance<py::list>(input_object)) {
+    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr);
+  } else if (py::isinstance<py::array>(input_object)) {
+    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
+  } else if (py::isinstance<py::tuple>(input_object)) {
+    auto tuple_inputs = py::cast<py::tuple>(input_object);
+    if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
+      PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensor);
+    } else {
+      ConvertValueTupleToTensor(input_object, input_tensor);
+    }
    return;
+  } else {
+    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
-  auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
-  // convert const input to attr
-  size_t input_num = op_inputs.size();
-  if (input_num != input_names_vec.size()) {
-    MS_LOG(EXCEPTION) << "input name number " << input_names_vec.size() << "is not equal to input value number "
-                      << input_num;
+  MS_EXCEPTION_IF_NULL(tensor_ptr);
+  input_tensor->push_back(tensor_ptr);
 }
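The dispatch above normalizes every supported Python object into tensors. A sketch (editor's example; that a nullptr TypePtr lets the Tensor constructor infer the dtype from the Python object is an assumption):

ConvertPyObjectToTensor(py::float_(1.5), op_prim, &input_tensors);  // pinned to kFloat32
ConvertPyObjectToTensor(py::int_(7), op_prim, &input_tensors);      // dtype inferred (nullptr TypePtr, assumed)
// Tuples branch on their first element: tensor tuples are planted individually,
// anything else goes through ConvertValueTupleToTensor -> opt::CreateTupleTensor.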
+void ConvertInputPyobject(const OpRunInfo &op_run_info, const PrimitivePtr &op_prim,
+                          std::vector<tensor::TensorPtr> *input_tensors, std::vector<bool> *tensors_mask) {
+  MS_EXCEPTION_IF_NULL(op_prim);
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  MS_EXCEPTION_IF_NULL(tensors_mask);
+  if (op_run_info.op_inputs.size() != op_run_info.inputs_mask.size()) {
+    MS_LOG(EXCEPTION) << "Op input size " << op_run_info.op_inputs.size() << " should be equal to op input mask size "
+                      << op_run_info.inputs_mask.size();
+  }
+  opt::ConstInputToAttrInfoRegister reg;
+  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info.op_name, &reg);
+  size_t input_num = op_run_info.op_inputs.size();
   MS_LOG(INFO) << "py input size: " << input_num;
   for (size_t index = 0; index < input_num; ++index) {
-    // skip tensor
-    if (py::isinstance<tensor::Tensor>(op_inputs[index])) {
-      continue;
-    }
-    // convert to attr
-    auto para_type = FindOpInputParameterType(op_run_info.op_name, kernel::OpImplyType::kTBE, index);
-    if (!para_type.empty() && para_type == kAttrDynInput) {
-      auto tuple_inputs = py::cast<py::tuple>(op_inputs[index]);
-      primitive->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
+    // convert const input to attr
+    if (reg_exist &&
+        RunOpConvertConstInputToAttr(op_run_info.op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
-    ValuePtr value = parse::data_converter::PyDataToValue(op_inputs[index]);
-    MS_EXCEPTION_IF_NULL(value);
-    auto input_name = input_names_vec[index];
-    // set the input node as attr of the cnode, key is name of input node,value is input node's value
-    primitive->set_attr(input_name, value);
+    // convert const and tuple input to tensor
+    ConvertPyObjectToTensor(op_run_info.op_inputs[index], op_prim, input_tensors);
+    // make tensors, weight : 1, data : 0
+    std::vector<bool> new_mask(input_tensors->size() - tensors_mask->size(),
+                               py::cast<bool>(op_run_info.inputs_mask[index]));
+    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
 }
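Mask bookkeeping note: one Python input can expand into several tensors (a planted tensor tuple), so its single mask bit is replicated across everything just appended, keeping input_tensors and tensors_mask the same length. Equivalent sketch using the fill-insert overload (editor's rewrite, not the patch):

size_t appended = input_tensors->size() - tensors_mask->size();
tensors_mask->insert(tensors_mask->end(), appended, py::cast<bool>(op_run_info.inputs_mask[index]));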
@@ -638,40 +697,6 @@ void SessionBasic::Summary(KernelGraph *graph) {
   summary_callback_(0, params_list);
 }
-void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor::TensorPtr> *inputs,
-                               std::vector<bool> *tensor_mask) {
-  MS_EXCEPTION_IF_NULL(inputs);
-  MS_EXCEPTION_IF_NULL(tensor_mask);
-  if (op_run_info.op_inputs.size() != op_run_info.inputs_mask.size()) {
-    MS_LOG(EXCEPTION) << "Op input size " << op_run_info.op_inputs.size() << " should be equal to op input mask size "
-                      << op_run_info.inputs_mask.size();
-  }
-  size_t input_num = op_run_info.op_inputs.size();
-  // get tensors from op_inputs
-  for (size_t i = 0; i < input_num; ++i) {
-    tensor::TensorPtr tensor_ptr = nullptr;
-    auto param_type = FindOpInputParameterType(op_run_info.op_name, kernel::OpImplyType::kTBE, i);
-    if (py::isinstance<tensor::Tensor>(op_run_info.op_inputs[i])) {
-      tensor_ptr = py::cast<tensor::TensorPtr>(op_run_info.op_inputs[i]);
-    } else if (!param_type.empty() && param_type == kAttrDynInput) {
-      auto tuple_inputs = py::cast<py::tuple>(op_run_info.op_inputs[i]);
-      for (auto &&tuple_input : tuple_inputs) {
-        tensor_ptr = py::cast<tensor::TensorPtr>(tuple_input);
-        MS_EXCEPTION_IF_NULL(tensor_ptr);
-        inputs->push_back(tensor_ptr);
-        tensor_mask->push_back(py::cast<bool>(op_run_info.inputs_mask[i]));
-      }
-      continue;
-    } else if (op_run_info.op_name == kApplyMomentumOpName && py::isinstance<py::float_>(op_run_info.op_inputs[i])) {
-      tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(op_run_info.op_inputs[i]), kFloat32);
-    }
-    if (tensor_ptr != nullptr) {
-      inputs->push_back(tensor_ptr);
-      tensor_mask->push_back(py::cast<bool>(op_run_info.inputs_mask[i]));
-    }
-  }
-}
 CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> output_args;
@@ -724,30 +749,27 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr
   MS_LOG(INFO) << "Finish!";
 }
-std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info) {
+std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
+                                                                  std::vector<tensor::TensorPtr> *input_tensors) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
   auto graph = std::make_shared<KernelGraph>();
   std::vector<AnfNodePtr> inputs;
-  if (op_run_info.op_inputs.size() != op_run_info.inputs_mask.size()) {
-    MS_LOG(EXCEPTION) << "op_run_info inputs.size" << op_run_info.op_inputs.size()
-                      << " should be equal to parameter_mask.size " << op_run_info.inputs_mask.size();
-  }
   // set input[0]
-  if (op_run_info.py_primitive == nullptr) {
-    inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(op_run_info.op_name)));
-  } else {
-    inputs.push_back(std::make_shared<ValueNode>(op_run_info.py_primitive));
+  PrimitivePtr op_prim = op_run_info.py_primitive;
+  if (op_prim == nullptr) {
+    op_prim = std::make_shared<Primitive>(op_run_info.op_name);
   }
+  inputs.push_back(std::make_shared<ValueNode>(op_prim));
   // set input parameter
-  std::vector<tensor::TensorPtr> input_tensors;
   std::vector<bool> tensors_mask;
-  ToTensorPtr(op_run_info, &input_tensors, &tensors_mask);
-  MS_LOG(INFO) << "Input tensor size" << input_tensors.size();
-  if (input_tensors.size() != tensors_mask.size()) {
-    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
+  ConvertInputPyobject(op_run_info, op_prim, input_tensors, &tensors_mask);
+  MS_LOG(INFO) << "Input tensor size: " << input_tensors->size();
+  if (input_tensors->size() != tensors_mask.size()) {
+    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
                       << tensors_mask.size();
   }
-  for (size_t i = 0; i < input_tensors.size(); ++i) {
-    auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
+  for (size_t i = 0; i < input_tensors->size(); ++i) {
+    auto parameter = ConstructRunOpParameter(graph, input_tensors->at(i), tensors_mask[i]);
     inputs.push_back(parameter);
     graph->MutableInputs()->push_back(parameter);
   }
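Design note on the out-parameter: the tensors created while the parameters are built must be exactly the ones loaded to device at run time, so ConstructSingleOpGraph now surfaces them and the sessions thread them from BuildOp to RunOp. Sketch of the contract (editor's summary, mirroring the pynative hunk earlier):

std::vector<tensor::TensorPtr> input_tensors;
session->BuildOp(op_run_info, graph_info, &input_tensors);  // filled during graph construction
session->RunOp(op_run_info, graph_info, input_tensors);     // same tensors fed to the device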
@@ -756,8 +778,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
   MS_EXCEPTION_IF_NULL(cnode);
   // set abstract,which include inferred shapes and types
   cnode->set_abstract(op_run_info.abstract);
-  // set const input to attr if value is not a tensor,such as scalar or tuple
-  RunOpConvertConstInputToAttr(op_run_info, cnode);
   // set execution order
   std::vector<CNodePtr> exe_order = {cnode};
   graph->set_execution_order(exe_order);
@@ -61,9 +61,11 @@ class SessionBasic {
   virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0;
-  virtual void BuildOp(const OpRunInfo &, const GraphInfo &) {}
+  virtual void BuildOp(const OpRunInfo &, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors) {}
-  virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &) { return py::tuple(); }
+  virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors) {
+    return py::tuple();
+  }
   virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
@@ -96,10 +98,8 @@ class SessionBasic {
   void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
   CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
   // create a single run op graph
-  std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info);
-  // get tensors from op inputs
-  void ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor::TensorPtr> *inputs,
-                   std::vector<bool> *tensor_mask);
+  std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
+                                                      std::vector<tensor::TensorPtr> *input_tensor);
   // trans BaseRef list to py::tuple
   BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);