| @@ -34,8 +34,8 @@ void Cast(const S *in, T *out, size_t size) { | |||
| template <typename S, typename T> | |||
| void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| source_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, 0); | |||
| target_dtype = AnfAlgo::GetOutputInferDataType(kernel_node, 0); | |||
| source_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| target_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0); | |||
| } | |||
| template <typename S, typename T> | |||
| @@ -45,7 +45,6 @@ bool CastCPUKernel<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| S *input = reinterpret_cast<S *>(inputs[0]->addr); | |||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | |||
| MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name(); | |||
| size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | |||
| Cast<S, T>(input, output, lens); | |||
| return true; | |||
| @@ -27,7 +27,7 @@ void MaximumGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||
| dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| dy_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| if (!x_shape_.size() || !y_shape_.size() || !dout_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "Input NULL"; | |||
| } | |||
| @@ -36,6 +36,11 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") | |||
| -Wno-overloaded-virtual -Wno-unused-const-variable -Wno-pessimizing-move") | |||
| endif() | |||
| if(ENABLE_CPU) | |||
| file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") | |||
| list(APPEND _PREACTIVATE_SRC_LIST ${_CPU_SRC_LIST}) | |||
| endif() | |||
| set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS | |||
| SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT) | |||
| add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST}) | |||
| @@ -0,0 +1,174 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/cpu/insert_cast_cpu.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||
| const TypeId &input_type, const TypeId &output_type, | |||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::string input_format = format; | |||
| std::string output_format = format; | |||
| CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input}); | |||
| MS_EXCEPTION_IF_NULL(cast); | |||
| // set kernel build info | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({input_format}); | |||
| builder.SetOutputsFormat({output_format}); | |||
| builder.SetInputsDeviceType({input_type}); | |||
| builder.SetOutputsDeviceType({output_type}); | |||
| // if kernel info is null , it remarks this function is running ut | |||
| if (cast->kernel_info() == nullptr) { | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| cast->set_kernel_info(kernel_info); | |||
| } | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | |||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); | |||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); | |||
| return cast; | |||
| } | |||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<bool> &need_insert_cast) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| size_t out_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { | |||
| AnfNodePtr replace_node = nullptr; | |||
| const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); | |||
| const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); | |||
| auto idx = NewValueNode(SizeToLong(output_idx)); | |||
| MS_EXCEPTION_IF_NULL(idx); | |||
| auto imm = std::make_shared<Int64Imm>(output_idx); | |||
| idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm)); | |||
| auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); | |||
| AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get()); | |||
| if (need_insert_cast[output_idx]) { | |||
| const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); | |||
| const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); | |||
| if (infer_type != device_type) { | |||
| replace_node = | |||
| AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, infer_type, origin_shape, infer_type); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| replace_node->set_scope(cnode->scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); | |||
| if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) { | |||
| kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return cnode; | |||
| } | |||
| void InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| size_t in_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| auto mng = kernel_graph->manager(); | |||
| for (size_t input_index = 0; input_index < in_num; ++input_index) { | |||
| auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index); | |||
| const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second); | |||
| auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); | |||
| const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); | |||
| const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second); | |||
| if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); infer_type != device_type) { | |||
| auto cast = | |||
| AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, infer_type, device_type, origin_shape, device_type); | |||
| MS_EXCEPTION_IF_NULL(cast); | |||
| cast->set_scope(cnode->scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); | |||
| mng->Replace(cur_input, cast); | |||
| } | |||
| } | |||
| } | |||
// Insert Cast nodes after cnode's outputs whose inferred dtype differs from
// the selected device dtype. Returns the node that should stand in for cnode
// (the cast for a single-output kernel, otherwise cnode itself).
// `need_insert_cast` has one flag per output; callers size it to the output count.
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                               const std::vector<bool> &need_insert_cast) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
    return cnode;
  }
  MS_EXCEPTION_IF_NULL(cnode->Type());
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  // Single output
  if (!cnode->Type()->isa<Tuple>()) {
    if (!need_insert_cast[0]) {
      return cnode;
    }
    const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
    const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
    AnfNodePtr replace_node = cnode;
    if (infer_type != device_type) {
      // Cast from the device dtype back to the inferred dtype so consumers
      // of this node observe the inferred type.
      replace_node =
        AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, infer_type, origin_shape, infer_type);
      MS_EXCEPTION_IF_NULL(replace_node);
      replace_node->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
      // Keep the kernel graph's internal-output bookkeeping consistent.
      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, 0)) {
        kernel_graph->ReplaceInternalOutput(cnode, replace_node);
      }
    }
    return replace_node;
  }
  // Multiple output
  return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
}
| } // namespace | |||
| const BaseRef InsertCastCPU::DefinePattern() const { | |||
| VarPtr V = std::make_shared<CondVar>(UnVisited); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({V, Xs}); | |||
| } | |||
// Pass entry: for every real CNode kernel, insert casts on inputs whose
// selected device dtype differs from the inferred dtype, then do the same for
// all outputs. Returns the replacement node, or nullptr if the node is skipped.
const AnfNodePtr InsertCastCPU::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
    return nullptr;
  }
  // Mark visited so the UnVisited pattern does not match this node again.
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  // process input
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  InsertCastForInput(func_graph, cnode);
  // process output: consider every output a cast candidate.
  return InsertCastForOutput(func_graph, cnode, std::vector<bool>(AnfAlgo::GetOutputTensorNum(cnode), true));
}
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H | |||
| #include <string> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "ir/anf.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// CPU backend pass that inserts Cast nodes wherever a kernel's selected device
// dtype differs from the inferred dtype of its inputs or outputs.
class InsertCastCPU : public PatternProcessPass {
 public:
  // multigraph=true — presumably lets the pass match across subgraphs; confirm
  // against PatternProcessPass.
  explicit InsertCastCPU(bool multigraph = true) : PatternProcessPass("insert_cast_cpu", multigraph) {}
  ~InsertCastCPU() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H | |||
| @@ -27,7 +27,9 @@ | |||
| #include "runtime/device/cpu/kernel_select_cpu.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/pass_manager.h" | |||
| #include "backend/optimizer/cpu/insert_cast_cpu.h" | |||
| #include "backend/optimizer/pass/replace_node_by_proxy.h" | |||
| #include "backend/optimizer/pass/erase_visit_attr.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "debug/dump_proto.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| @@ -61,9 +63,21 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos | |||
| void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| std::string pass_name = "replace_node_by_proxy"; | |||
| pass_name.append(std::to_string(graph_sum_)); | |||
| pm->AddPass(std::make_shared<opt::ReplaceNodeByProxy>(pass_name)); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) { | |||
| AssignParamKey(kernel_graph); | |||
| if (ps::PSContext::instance()->is_worker()) { | |||
| std::string pass_name = "replace_node_by_proxy"; | |||
| pass_name.append(std::to_string(graph_sum_)); | |||
| pm->AddPass(std::make_shared<opt::ReplaceNodeByProxy>(pass_name)); | |||
| } | |||
| } | |||
| #endif | |||
| pm->AddPass(std::make_shared<opt::InsertCastCPU>()); | |||
| pm->AddPass(std::make_shared<opt::EraseVisitAttr>()); | |||
| MS_LOG(INFO) << "insert cast pass"; | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| @@ -77,14 +91,8 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||
| graph->UpdateGraphDynamicAttr(); | |||
| MS_LOG(INFO) << "Set kernel info"; | |||
| SetKernelInfo(graph.get()); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::PSContext::instance()->is_ps_mode()) { | |||
| AssignParamKey(graph); | |||
| if (ps::PSContext::instance()->is_worker()) { | |||
| Optimize(graph); | |||
| } | |||
| } | |||
| #endif | |||
| MS_LOG(INFO) << "Set kernel info end"; | |||
| Optimize(graph); | |||
| MS_LOG(INFO) << "Build kernel"; | |||
| BuildKernel(graph.get()); | |||
| @@ -168,6 +176,7 @@ void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap | |||
| auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| SetKernelInfo(kernel_graph.get()); | |||
| Optimize(kernel_graph); | |||
| BuildKernel(kernel_graph.get()); | |||
| run_op_graphs_[graph_info] = kernel_graph; | |||
| } | |||
| @@ -35,21 +35,6 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) { | |||
| return false; | |||
| } | |||
| void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector<size_t> &input_not_cnode_indexes, | |||
| const CNodePtr kernel_node) { | |||
| for (auto &input_index : input_not_cnode_indexes) { | |||
| auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| std::vector<TypeId> output_types; | |||
| output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first); | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| builder->SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| builder->SetOutputsDeviceType(output_types); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get()); | |||
| } | |||
| } | |||
| void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *output_formats, | |||
| std::vector<TypeId> *output_types) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| @@ -142,35 +127,11 @@ std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, | |||
| int format_matched_num = 0; | |||
| auto input_num = input_types.size(); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), | |||
| [i](size_t index) { return index == i; }); | |||
| bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); | |||
| if (have_cnode_input && is_not_cnode_idx) { | |||
| data_type_matched_num++; | |||
| format_matched_num++; | |||
| continue; | |||
| } | |||
| if (is_not_cnode_idx) { | |||
| if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) { | |||
| MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first | |||
| << ", actual input dtype:" << input_types[i]; | |||
| } else { | |||
| data_type_matched_num++; | |||
| } | |||
| format_matched_num++; | |||
| continue; | |||
| } | |||
| if (kernel_attr.GetInputAttr(i).first != input_types[i]) { | |||
| if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) { | |||
| MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first | |||
| << ", actual input dtype:" << input_types[i]; | |||
| } else { | |||
| data_type_matched_num++; | |||
| } | |||
| if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { | |||
| MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second | |||
| << ", actual input format:" << input_formats[i]; | |||
| } else { | |||
| format_matched_num++; | |||
| } | |||
| } | |||
| @@ -320,9 +281,8 @@ void SetKernelInfo(const CNodePtr &kernel_node) { | |||
| (matched.first || input_types.size() == input_not_cnode_indexes.size())) { | |||
| MS_LOG(INFO) << "Input format and dtype is matched"; | |||
| GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types); | |||
| UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node); | |||
| for (auto &input_index : input_not_cnode_indexes) { | |||
| input_types[input_index] = selected_kernel_attr.GetInputAttr(input_index).first; | |||
| for (size_t i = 0; i < selected_kernel_attr.GetInputSize(); ++i) { | |||
| input_types[SizeToInt(i)] = selected_kernel_attr.GetInputAttr(i).first; | |||
| } | |||
| } | |||
| SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get()); | |||