| @@ -80,7 +80,7 @@ class AbstractShapeCreator { | |||||
| return {device_shape[0], device_shape[3], device_shape[1], device_shape[2]}; | return {device_shape[0], device_shape[3], device_shape[1], device_shape[2]}; | ||||
| } | } | ||||
| static ShapeVector FractalNzAbstractShape(const ShapeVector &device_shape) { | static ShapeVector FractalNzAbstractShape(const ShapeVector &device_shape) { | ||||
| if (device_shape.size() == 1 && (device_shape[0] == 1 || device_shape[0] % kCubeSize == 0)) { | |||||
| if (device_shape.size() == 1 && (device_shape[0] == 1 || static_cast<size_t>(device_shape[0]) % kCubeSize == 0)) { | |||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| if (device_shape.size() < 4) { | if (device_shape.size() < 4) { | ||||
| @@ -126,7 +126,7 @@ class CNodeDecoder { | |||||
| } | } | ||||
| private: | private: | ||||
| ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) { | |||||
| ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) const { | |||||
| if (type == "str") { | if (type == "str") { | ||||
| std::string value = attr_json[kJsonKeyValue]; | std::string value = attr_json[kJsonKeyValue]; | ||||
| return MakeValue(value); | return MakeValue(value); | ||||
| @@ -204,7 +204,6 @@ class CNodeDecoder { | |||||
| bool DecodeOutputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) { | bool DecodeOutputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) { | ||||
| std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc]; | std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc]; | ||||
| AbstractBasePtr abstract(nullptr); | |||||
| if (output_descs.empty()) { | if (output_descs.empty()) { | ||||
| MS_LOG(ERROR) << "No outputs found."; | MS_LOG(ERROR) << "No outputs found."; | ||||
| return false; | return false; | ||||
| @@ -288,7 +287,7 @@ class CNodeDecoder { | |||||
| return primitive; | return primitive; | ||||
| } | } | ||||
| tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) { | |||||
| tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) const { | |||||
| auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]); | auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]); | ||||
| switch (type_id) { | switch (type_id) { | ||||
| case kNumberTypeFloat16: | case kNumberTypeFloat16: | ||||
| @@ -435,7 +434,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js | |||||
| return DecodeFusedNodes(kernel_json); | return DecodeFusedNodes(kernel_json); | ||||
| } | } | ||||
| StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) { | |||||
| StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) const { | |||||
| StitchInfo info; | StitchInfo info; | ||||
| if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) { | if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) { | ||||
| nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch]; | nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch]; | ||||
| @@ -451,7 +450,8 @@ StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json | |||||
| return info; | return info; | ||||
| } | } | ||||
| void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) { | |||||
| void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, | |||||
| const CNodePtr &node) const { | |||||
| std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc]; | std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc]; | ||||
| if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return; | if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return; | ||||
| std::string tensor_name = output_descs[0][kJsonKeyTensorName]; | std::string tensor_name = output_descs[0][kJsonKeyTensorName]; | ||||
| @@ -44,8 +44,8 @@ class AkgKernelJsonDecoder { | |||||
| ParameterPtr DecodeParameter(const nlohmann::json ¶meter_json, const FuncGraphPtr &func_graph); | ParameterPtr DecodeParameter(const nlohmann::json ¶meter_json, const FuncGraphPtr &func_graph); | ||||
| CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor); | CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor); | ||||
| AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph); | AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph); | ||||
| StitchInfo GetStitchInfo(const nlohmann::json &kernel_json); | |||||
| void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node); | |||||
| StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const; | |||||
| void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const; | |||||
| std::map<std::string, AnfNodePtr> nodes_map_; | std::map<std::string, AnfNodePtr> nodes_map_; | ||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -23,15 +23,15 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) { return x >= 0 ? x : x + static_cast<int64_t>(rank); } | |||||
| int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); } | |||||
| bool AxisNormalizer::IsReduce(const AnfNodePtr &node) { | |||||
| bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const { | |||||
| std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}; | std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}; | ||||
| return std::any_of(node_with_axis.begin(), node_with_axis.end(), | return std::any_of(node_with_axis.begin(), node_with_axis.end(), | ||||
| [&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); }); | [&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); }); | ||||
| } | } | ||||
| bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) { | |||||
| bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const { | |||||
| bool changed = false; | bool changed = false; | ||||
| auto todos = TopoSort(func_graph->get_return()); | auto todos = TopoSort(func_graph->get_return()); | ||||
| for (auto node : todos) { | for (auto node : todos) { | ||||
| @@ -48,8 +48,8 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) { | |||||
| bool diff = false; | bool diff = false; | ||||
| ShapeVector axis_vec; | ShapeVector axis_vec; | ||||
| if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) { | if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) { | ||||
| int64_t v1 = GetValue<int64_t>(axis); | |||||
| int64_t v2 = NormAxis(v1, rank); | |||||
| auto v1 = GetValue<int64_t>(axis); | |||||
| auto v2 = NormAxis(v1, rank); | |||||
| axis_vec.push_back(v2); | axis_vec.push_back(v2); | ||||
| diff = diff || (v1 != v2); | diff = diff || (v1 != v2); | ||||
| } else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) { | } else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) { | ||||
| @@ -61,8 +61,8 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) { | } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) { | ||||
| for (auto v : vec) { | for (auto v : vec) { | ||||
| int64_t v1 = GetValue<int64_t>(v); | |||||
| int64_t v2 = NormAxis(v1, rank); | |||||
| auto v1 = GetValue<int64_t>(v); | |||||
| auto v2 = NormAxis(v1, rank); | |||||
| axis_vec.push_back(v2); | axis_vec.push_back(v2); | ||||
| diff = diff || (v1 != v2); | diff = diff || (v1 != v2); | ||||
| } | } | ||||
| @@ -29,9 +29,9 @@ class AxisNormalizer : public Pass { | |||||
| bool Run(const FuncGraphPtr &func_graph) override; | bool Run(const FuncGraphPtr &func_graph) override; | ||||
| private: | private: | ||||
| bool Process(const FuncGraphPtr &func_graph); | |||||
| int64_t NormAxis(int64_t x, size_t rank); | |||||
| bool IsReduce(const AnfNodePtr &node); | |||||
| bool Process(const FuncGraphPtr &func_graph) const; | |||||
| int64_t NormAxis(int64_t x, size_t rank) const; | |||||
| bool IsReduce(const AnfNodePtr &node) const; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,7 +60,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che | |||||
| } | } | ||||
| circle_nodes->clear(); | circle_nodes->clear(); | ||||
| auto InputEdges = [&depend_prior](CNodePtr cnode) { | |||||
| auto InputEdges = [&depend_prior](const CNodePtr &cnode) { | |||||
| std::set<AnfNodePtr> edges; | std::set<AnfNodePtr> edges; | ||||
| auto range = depend_prior.equal_range(cnode); | auto range = depend_prior.equal_range(cnode); | ||||
| for (auto iter = range.first; iter != range.second; ++iter) { | for (auto iter = range.first; iter != range.second; ++iter) { | ||||
| @@ -30,7 +30,7 @@ std::string CommonDimInfo::ToString() { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) { | |||||
| int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const { | |||||
| nlohmann::json json_desc; | nlohmann::json json_desc; | ||||
| AnfNodePtrList nodes = {node}; | AnfNodePtrList nodes = {node}; | ||||
| DumpOption dump_option; | DumpOption dump_option; | ||||
| @@ -47,7 +47,8 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) { | |||||
| return py::cast<int>(ret); | return py::cast<int>(ret); | ||||
| } | } | ||||
| std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) { | |||||
| std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo( | |||||
| const AnfNodePtrList &nodes) const { | |||||
| nlohmann::json json_desc; | nlohmann::json json_desc; | ||||
| std::vector<AnfNodePtrList> graphs; | std::vector<AnfNodePtrList> graphs; | ||||
| std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs), | std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs), | ||||
| @@ -80,7 +81,7 @@ std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFu | |||||
| return std::make_tuple(dim_infos, benefit, fusion_info); | return std::make_tuple(dim_infos, benefit, fusion_info); | ||||
| } | } | ||||
| FusionInfoPtr ParallelCostModel::ProcessFusionInfo(py::object fusion_type, py::object type_info) { | |||||
| FusionInfoPtr ParallelCostModel::ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const { | |||||
| if (!py::isinstance<py::str>(fusion_type)) { | if (!py::isinstance<py::str>(fusion_type)) { | ||||
| MS_LOG(EXCEPTION) << "Fusion type for parallel is invalid!"; | MS_LOG(EXCEPTION) << "Fusion type for parallel is invalid!"; | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ namespace opt { | |||||
| class DimInfo { | class DimInfo { | ||||
| public: | public: | ||||
| DimInfo() = default; | DimInfo() = default; | ||||
| ~DimInfo() {} | |||||
| virtual ~DimInfo() {} | |||||
| virtual std::string ToString() = 0; | virtual std::string ToString() = 0; | ||||
| }; | }; | ||||
| @@ -60,7 +60,7 @@ class FusionInfo { | |||||
| public: | public: | ||||
| FusionInfo() = default; | FusionInfo() = default; | ||||
| explicit FusionInfo(const std::string &type) : fusion_type_(type) {} | explicit FusionInfo(const std::string &type) : fusion_type_(type) {} | ||||
| ~FusionInfo() = default; | |||||
| virtual ~FusionInfo() = default; | |||||
| std::string FusionType() { return fusion_type_; } | std::string FusionType() { return fusion_type_; } | ||||
| virtual bool ExistTypeInfo() { return false; } | virtual bool ExistTypeInfo() { return false; } | ||||
| @@ -72,7 +72,7 @@ class BlockFusionInfo : public FusionInfo { | |||||
| public: | public: | ||||
| BlockFusionInfo() : FusionInfo("block_fusion") {} | BlockFusionInfo() : FusionInfo("block_fusion") {} | ||||
| ~BlockFusionInfo() = default; | ~BlockFusionInfo() = default; | ||||
| bool ExistTypeInfo() { return false; } | |||||
| bool ExistTypeInfo() override { return false; } | |||||
| }; | }; | ||||
| class BlockPipelineFusionInfo : public FusionInfo { | class BlockPipelineFusionInfo : public FusionInfo { | ||||
| @@ -80,7 +80,7 @@ class BlockPipelineFusionInfo : public FusionInfo { | |||||
| explicit BlockPipelineFusionInfo(const std::vector<std::vector<int>> &ids) | explicit BlockPipelineFusionInfo(const std::vector<std::vector<int>> &ids) | ||||
| : FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {} | : FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {} | ||||
| ~BlockPipelineFusionInfo() = default; | ~BlockPipelineFusionInfo() = default; | ||||
| bool ExistTypeInfo() { return true; } | |||||
| bool ExistTypeInfo() override { return true; } | |||||
| std::vector<std::vector<int>> PipelineIds() { return pipeline_ids_; } | std::vector<std::vector<int>> PipelineIds() { return pipeline_ids_; } | ||||
| private: | private: | ||||
| @@ -95,11 +95,11 @@ class ParallelCostModel { | |||||
| public: | public: | ||||
| ParallelCostModel() {} | ParallelCostModel() {} | ||||
| ~ParallelCostModel() {} | ~ParallelCostModel() {} | ||||
| int GetNodeCalAmount(const AnfNodePtr &node); | |||||
| std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes); | |||||
| int GetNodeCalAmount(const AnfNodePtr &node) const; | |||||
| std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes) const; | |||||
| private: | private: | ||||
| FusionInfoPtr ProcessFusionInfo(py::object fusion_type, py::object type_info); | |||||
| FusionInfoPtr ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const; | |||||
| }; | }; | ||||
| using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>; | using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>; | ||||