The "throw" statement is not allowed in mindspore project (codedex check), so we remove the self-define exception and replace with MS_LOG(EXCEPTION). In GraphKernelExpanders, we check the return value instead. The rollback function in ArithmeticSimplify / TrnasformOpOptimizer is not supported now. what's more, changed the c++ op expanders from .h files to .cc files, the OpExpanderRegister is called in each .cc file, likes the operator registers in mindspore.tags/v1.5.0-rc1
@@ -158,12 +158,10 @@ PatternNodePtr PatternTree::BuildTree(const std::string &pattern_str) {
}
cur_node->AddInput(BuildTree(op_inputs));
return cur_node;
} else {
return std::make_shared<PatternNode>(pattern_str);
}
}
return nullptr;
}
@@ -276,7 +274,6 @@ bool DfsMatchGraph(const graphkernel::NodePtr &tmp_node, const PatternNodePtr &t
return false;
}
}
} else {
for (size_t i = 0; i < tmp_pattern_inputs.size(); i++) {
if (!DfsMatchGraph(tmp_node_inputs[i], tmp_pattern_inputs[i], para_to_ref, const_to_ref, res)) {
@@ -387,7 +384,6 @@ class ExtraReduce1PatternTree : public PatternTree {
for (auto &i : GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second)) {
axis_set.insert(i);
}
} else {
auto first_axis = GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second);
auto second_axis = GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second);
@@ -538,7 +534,6 @@ std::unordered_map<std::string, std::vector<PatternTreePtr>> GetExpressions() {
std::unordered_set<std::string> enable_ids{flags.enable_simplify_exprs_only.begin(),
flags.enable_simplify_exprs_only.end()};
std::unordered_set<std::string> disable_ids{flags.disable_simplify_exprs.begin(), flags.disable_simplify_exprs.end()};
for (auto &e : expressions) {
if (!enable_ids.empty()) {
if (enable_ids.count(std::to_string(e.id)) == 0) continue;
@@ -640,33 +635,29 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
expressions_map_ = GetExpressions();
for (auto node : func_graph->GetOrderedCnodes()) {
if (AnfAlgo::IsGraphKernel(node)) {
try {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
bool find_pattern = true;
bool change_anf_graph = false;
while (find_pattern) {
find_pattern = false;
find_pattern = DoArithmeticTrans(lg) || find_pattern;
find_pattern = DoConstantFold(lg) || find_pattern;
change_anf_graph = change_anf_graph || find_pattern;
}
if (!change_anf_graph) continue;
ReorganizeEmptyGraph(lg);
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
EliminateRedundantParameters(new_funcgraph, &inputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
do_simplify = true;
} catch (const graphkernel::GKException &e) {
MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph";
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
bool find_pattern = true;
bool change_anf_graph = false;
while (find_pattern) {
find_pattern = false;
find_pattern = DoArithmeticTrans(lg) || find_pattern;
find_pattern = DoConstantFold(lg) || find_pattern;
change_anf_graph = change_anf_graph || find_pattern;
}
if (!change_anf_graph) continue;
ReorganizeEmptyGraph(lg);
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
EliminateRedundantParameters(new_funcgraph, &inputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
do_simplify = true;
}
}
return do_simplify;
@@ -13,14 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
#include "backend/optimizer/graph_kernel/expanders/utils.h"
namespace mindspore {
@@ -34,7 +32,8 @@ class BiasAdd : public OpExpander {
support_format->AddFormat({kOpFormat_NCHW, kOpFormat_DEFAULT});
support_format->AddFormat({kOpFormat_NHWC, kOpFormat_DEFAULT});
validators_.emplace_back(std::move(support_format));
validators_.emplace_back(new CheckAttr({"format"}));
auto attrs = std::initializer_list<std::string>{"format"};
validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
}
~BiasAdd() = default;
NodePtrList Expand() override {
@@ -42,19 +41,19 @@ class BiasAdd : public OpExpander {
auto input_x = inputs[0];
auto input_y = inputs[1];
if (input_x->format == kOpFormat_NCHW) {
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, {1, 2}))}});
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDimsInferShape(input_y->shape, {1, 2}))}});
} else if (input_x->format == kOpFormat_DEFAULT) {
auto data_format = GetValue<std::string>(attrs_["format"]);
size_t channel_idx = (data_format == kOpFormat_NHWC) ? input_x->shape.size() - 1 : 1;
std::vector<int64_t> axis(input_x->shape.size() - channel_idx - 1, -1);
if (!axis.empty()) {
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, axis))}});
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDimsInferShape(input_y->shape, axis))}});
}
}
return {gb.Emit("Add", {input_x, input_y})};
}
};
OP_EXPANDER_REGISTER("BiasAdd", BiasAdd);
} // namespace expanders
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
@@ -22,22 +22,15 @@
#include <memory>
#include "backend/optimizer/graph_kernel/expanders/utils.h"
#include "backend/optimizer/graph_kernel/expanders/reshape.h"
#include "backend/optimizer/graph_kernel/expanders/bias_add.h"
namespace mindspore {
namespace opt {
namespace expanders {
#define OP_EXPANDER_CREATOR(cls) []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); }
class OpExpanderFactory {
public:
static OpExpanderFactory &Instance() {
static std::unique_ptr<OpExpanderFactory> instance = nullptr;
if (instance == nullptr) {
instance.reset(new OpExpanderFactory());
}
return *instance;
static OpExpanderFactory instance;
return instance;
}
std::shared_ptr<OpExpander> GetExpander(const std::string &op) {
if (auto iter = creators.find(op); iter != creators.end()) {
@@ -49,16 +42,24 @@ class OpExpanderFactory {
}
~OpExpanderFactory() = default;
private:
using RegFunc = std::function<std::shared_ptr<OpExpander>()>;
void Register(std::string &&op, RegFunc &&func) { creators.insert({op, func}); }
OpExpanderFactory() {
Register("BiasAdd", OP_EXPANDER_CREATOR(expanders::BiasAdd));
Register("ExpandDims", OP_EXPANDER_CREATOR(expanders::ExpandDims));
}
void Register(const std::string &op, const RegFunc &func) { creators[op] = func; }
private:
std::unordered_map<std::string, RegFunc> creators;
};
class OpExpanderRegister {
public:
OpExpanderRegister(const std::string &name, const OpExpanderFactory::RegFunc &func) {
OpExpanderFactory::Instance().Register(name, func);
}
~OpExpanderRegister() = default;
};
#define OP_EXPANDER_REGISTER(name, cls) \
static const OpExpanderRegister g_##cls##_expander_reg( \
name, []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); })
} // namespace expanders
} // namespace opt
} // namespace mindspore
@@ -13,25 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
#include <memory>
#include <vector>
#include "backend/optimizer/graph_kernel/model/node.h"
#include "backend/optimizer/graph_kernel/expanders/utils.h"
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
namespace mindspore {
namespace opt {
namespace expanders {
class ExpandDims : public OpExpander {
public:
ExpandDims() { validators_.emplace_back(new CheckAttr({"axis"})); }
~ExpandDims() {}
ExpandDims() {
std::initializer_list<std::string> attrs{"axis"};
validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
}
~ExpandDims() = default;
NodePtrList Expand() override {
const auto &inputs = gb.Get()->inputs();
auto &input_x = inputs[0];
const auto &input_x = inputs[0];
auto shape = MakeValue(ExpandDims::InferShape(input_x->shape, GetAxisList(this->attrs_["axis"])));
auto result = gb.Emit("Reshape", {input_x}, {{"shape", shape}});
return {result};
@@ -42,9 +41,7 @@ class ExpandDims : public OpExpander {
for (auto x : axis) {
int64_t rank = static_cast<int64_t>(new_shape.size());
if (x > rank || x < -rank - 1) {
std::ostringstream oss;
oss << "ExpandDims axis " << x << " is out of range of size " << new_shape.size();
throw graphkernel::GKException(oss.str());
MS_LOG(EXCEPTION) << "ExpandDims axis " << x << " is out of range of size " << new_shape.size();
}
if (x >= 0) {
new_shape.insert(new_shape.begin() + x, 1LL);
@@ -55,7 +52,11 @@ class ExpandDims : public OpExpander {
return new_shape;
}
};
OP_EXPANDER_REGISTER("ExpandDims", ExpandDims);
ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis) {
return ExpandDims::InferShape(shape, axis);
}
} // namespace expanders
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
@@ -31,27 +31,31 @@ graphkernel::LiteGraphPtr OpExpander::Run(const BaseInfoList &inputs, const Base
this->outputs_info_ = outputs;
this->attrs_ = attrs;
this->processor_ = processor;
for (const auto &v : validators_) {
v->Check(*this);
if (std::any_of(validators_.begin(), validators_.end(),
[this](const std::unique_ptr<Validator> &v) { return !(v->Check(*this)); })) {
return nullptr;
}
if (!this->CheckInputs()) {
return nullptr;
}
this->CheckInputs();
for (auto &inp : inputs) {
(void)gb.Parameter(inp);
}
auto result = this->Expand();
gb.SetOutputs(result);
this->CheckOutputs();
if (!this->CheckOutputs()) {
return nullptr;
}
return gb.Get();
}
void OpExpander::CheckOutputs() {
bool OpExpander::CheckOutputs() {
// check the output shape/type/format are same as the original basic node's output.
const NodePtrList &outputs = gb.Get()->GetOutputs();
if (outputs.size() != this->outputs_info_.size()) {
std::ostringstream oss;
oss << "the output num was not equal to the original output num : " << outputs.size() << " vs "
<< outputs_info_.size();
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "the output num was not equal to the original output num : " << outputs.size() << " vs "
<< outputs_info_.size();
return false;
}
for (size_t i = 0; i < outputs.size(); i++) {
if (outputs[i]->shape != outputs_info_[i].shape) {
@@ -65,21 +69,21 @@ void OpExpander::CheckOutputs() {
oss << s << ",";
}
oss << "]";
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << oss.str();
return false;
}
if (outputs[i]->type != outputs_info_[i].type) {
std::ostringstream oss;
oss << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
<< outputs_info_[i].type << "]";
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
<< outputs_info_[i].type << "]";
return false;
}
if (outputs[i]->format != outputs_info_[i].format) {
std::ostringstream oss;
oss << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
<< outputs_info_[i].format << "]";
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
<< outputs_info_[i].format << "]";
return false;
}
}
return true;
}
std::vector<int64_t> GetAxisList(const ValuePtr &value) {
@@ -37,9 +37,9 @@ class OpExpander {
virtual ~OpExpander() = default;
protected:
virtual void CheckInputs() {}
virtual bool CheckInputs() { return true; }
virtual NodePtrList Expand() = 0;
void CheckOutputs();
bool CheckOutputs();
graphkernel::LiteGraph::GraphBuilder gb;
std::string op_;
@@ -57,37 +57,36 @@ class OpExpander {
class Validator {
public:
virtual void Check(const OpExpander &e) = 0;
virtual bool Check(const OpExpander &e) = 0;
};
class CheckAllFormatsSame : public Validator {
public:
void Check(const OpExpander &e) override {
if (e.inputs_info_.empty()) return;
bool Check(const OpExpander &e) override {
if (e.inputs_info_.empty()) return true;
const auto &fmt_0 = e.inputs_info_[0].format;
for (size_t i = 1; i < e.inputs_info_.size(); i++) {
if (e.inputs_info_[i].format != fmt_0) {
std::ostringstream oss;
oss << "Unmatched format for op " << e.op_;
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "Unmatched format for op " << e.op_;
return false;
}
}
return true;
}
};
class CheckAttr : public Validator {
public:
CheckAttr() = default;
CheckAttr(std::initializer_list<std::string> l) : attrs_(l) {}
~CheckAttr() = default;
void Check(const OpExpander &e) override {
bool Check(const OpExpander &e) override {
for (auto &a : attrs_) {
if (e.attrs_.count(a) == 0) {
std::ostringstream oss;
oss << "attr " << a << " does not exist. op " << e.op_;
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.op_;
return false;
}
}
return true;
}
private:
@@ -97,7 +96,7 @@ class CheckAttr : public Validator {
class SupportFormat : public Validator {
public:
void AddFormat(std::initializer_list<std::string> l) { formats_.emplace_back(l); }
void Check(const OpExpander &e) override {
bool Check(const OpExpander &e) override {
for (auto &formats : formats_) {
if (formats.size() != e.inputs_info_.size()) {
continue;
@@ -110,12 +109,11 @@ class SupportFormat : public Validator {
}
}
if (match) {
return;
return true;
}
}
std::ostringstream oss;
oss << "unsupported format for op " << e.op_;
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "unsupported format for op " << e.op_;
return false;
}
private:
@@ -123,6 +121,7 @@ class SupportFormat : public Validator {
};
std::vector<int64_t> GetAxisList(const ValuePtr &value);
ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis);
} // namespace expanders
} // namespace opt
} // namespace mindspore
@@ -153,13 +153,12 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
outputs[i].format = AnfAlgo::GetOutputFormat(node, i);
}
auto &attrs = AnfAlgo::GetCNodePrimitive(node)->attrs();
try {
auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
return LiteGraph2AnfGraph(litegraph);
} catch (const graphkernel::GKException &e) {
MS_LOG(INFO) << e.what() << ", undo expanding this op";
auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
if (litegraph == nullptr) {
MS_LOG(INFO) << "undo expanding " << node->fullname_with_scope();
return nullptr;
}
return LiteGraph2AnfGraph(litegraph);
}
AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
@@ -152,15 +152,6 @@ class OutputNode : public Node {
void Dump(std::ostringstream &os) const override { ; }
NType NodeType() override { return NType::Output; }
};
class GKException : public std::exception {
public:
explicit GKException(const std::string &message) : msg_(message) {}
const char *what() const noexcept override { return msg_.c_str(); }
protected:
std::string msg_;
};
} // namespace graphkernel
} // namespace opt
} // namespace mindspore
@@ -200,29 +200,35 @@ NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const
// default format shape to fractal_Nz format shape
DShape ToNz(const DShape &default_shape) {
if (default_shape.size() != 1 && default_shape.size() != 2) {
throw GKException("shape is too long");
}
DShape output_shape;
if (default_shape.size() == 1 || (default_shape.size() == 2 && default_shape[0] == 1)) {
output_shape = {default_shape[default_shape.size() - 1] / 16, 1, 1, 16};
if (default_shape[default_shape.size() - 1] % 16 != 0) {
throw GKException("should be multiplies of 16");
constexpr size_t nz_size = 2;
auto len = default_shape.size();
DShape leading_shape;
DShape tail_shape;
if (default_shape.size() > nz_size) {
leading_shape.insert(leading_shape.end(), default_shape.begin(), default_shape.end() - nz_size);
}
if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) {
// (32) or (N, 1, 32) -> (N, 2, 1, 1, 16)
if (default_shape.back() % 16 != 0) {
MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back();
}
} else if (default_shape.size() == 2 || default_shape[1] == 1) {
output_shape = {1, default_shape[0] / 16, 16, 1};
if (default_shape[0] % 16 != 0) {
throw GKException("should be multiplies of 16");
tail_shape = {default_shape.back() / 16, 1, 1, 16};
} else if (default_shape.size() >= nz_size || default_shape[1] == 1) {
// (N, 32, 1) -> (N, 1, 2, 16, 1)
if (default_shape[len - nz_size] % 16 != 0) {
MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size];
}
tail_shape = {1, default_shape[0] / 16, 16, 1};
} else {
output_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16};
if (default_shape[0] % 16 != 0 || default_shape[1] % 16 != 0) {
throw GKException("should be multiplies of 16");
// (N, 32, 48) -> (N, 3, 2, 16, 16)
if (default_shape.back() % 16 != 0 || default_shape[len - nz_size] % 16 != 0) {
MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got "
<< default_shape.back() << " " << default_shape[len - nz_size];
}
tail_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16};
}
return output_shape;
leading_shape.insert(leading_shape.end(), tail_shape.begin(), tail_shape.end());
return leading_shape;
}
DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
@@ -252,7 +258,7 @@ DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
output_shape[i] = align_shape[i];
}
if (output_shape[i] != align_shape[i]) {
throw GKException("shape broadcast failed");
MS_LOG(EXCEPTION) << "Shape broadcast failed. " << output_shape[i] << " vs " << align_shape[i];
}
}
}
@@ -272,7 +278,7 @@ DShape ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
})) {
return BroadcastShape(inputs, true);
}
throw GKException("Only support default and fractal_nz");
MS_LOG(EXCEPTION) << "Unsupported format.";
}
DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
@@ -374,22 +380,20 @@ DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
return new_shape;
}
void CheckNd(const std::vector<int64_t> &shape, size_t n) {
if (shape.size() != n) {
std::ostringstream info;
info << "input dimension should be " << n << ", but got " << shape.size();
throw GKException(info.str());
}
}
DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto check_nd = [](const std::vector<int64_t> &shape, size_t n) {
if (shape.size() != n) {
MS_LOG(EXCEPTION) << "input dimension should be " << n << ", but got " << shape.size();
}
};
auto shape0 = inputs[0]->shape;
auto shape1 = inputs[1]->shape;
CheckNd(shape0, 4);
CheckNd(shape1, 4);
check_nd(shape0, 4);
check_nd(shape1, 4);
CHECK_ATTR(attrs, "format");
if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC &&
GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) {
throw GKException("check NHWC format failed");
MS_LOG(EXCEPTION) << "check NHWC format failed";
}
auto n = shape0[0];
auto h = shape0[1];
@@ -405,10 +409,10 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
auto stride = GetListInt(attrs.find("stride")->second);
auto dilation = GetListInt(attrs.find("dilation")->second);
CheckNd(pad_list, 4);
CheckNd(kernel_size, 2);
CheckNd(stride, 4);
CheckNd(dilation, 4);
check_nd(pad_list, 4);
check_nd(kernel_size, 2);
check_nd(stride, 4);
check_nd(dilation, 4);
bool has_pad = false;
if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) {
has_pad = true;
@@ -464,19 +468,17 @@ DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs)
std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2};
if (perm == nhwc2nchw) return kOpFormat_DEFAULT;
}
std::ostringstream info;
info << "Unsupported Transpose. ori_format = " << ori_format << ", perm = " << attrs.find("perm")->second->ToString();
throw GKException(info.str());
return kOpFormat_DEFAULT;
}
DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
std::vector<int64_t> shape0 = inputs[0]->shape;
std::vector<int64_t> shape1 = inputs[1]->shape;
if (shape0.size() != 2 || shape1.size() != 2) {
std::ostringstream info;
info << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size();
throw GKException(info.str());
MS_LOG(EXCEPTION) << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size();
}
CHECK_ATTR(attrs, "transpose_a");
CHECK_ATTR(attrs, "transpose_b");
auto transpose_a = GetValue<bool>(attrs.find("transpose_a")->second);
auto transpose_b = GetValue<bool>(attrs.find("transpose_b")->second);
int64_t m = transpose_a ? shape0[1] : shape0[0];
@@ -491,6 +493,7 @@ DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
}
TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
CHECK_ATTR(attrs, "dst_type");
if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
auto dst_type = attrs.find("dst_type")->second;
if (dst_type->isa<Type>()) {
@@ -502,6 +505,8 @@ TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
std::vector<int64_t> shape0 = inputs[0]->shape;
size_t n = shape0.size();
CHECK_ATTR(attrs, "head");
CHECK_ATTR(attrs, "tail");
std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second);
std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second);
if (pad_before.size() != n || pad_after.size() != n) {
@@ -518,6 +523,7 @@ DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
std::vector<int64_t> shape0 = inputs[0]->shape;
size_t n = shape0.size();
CHECK_ATTR(attrs, "tail");
std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second);
if (unpad_after.size() != n) {
MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size();
@@ -531,13 +537,12 @@ DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
if (inputs[0]->type != TypeId::kNumberTypeFloat32) {
throw GKException("Complex's input[0] should be float32");
MS_LOG(EXCEPTION) << "Complex's input[0] should be float32";
}
if (inputs[0]->type != inputs[1]->type) {
MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch";
}
}
} // namespace graphkernel
} // namespace opt
} // namespace mindspore
@@ -251,7 +251,7 @@ class CImagOp : public ElemwiseOp {
protected:
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
throw GKException("CImag's input[0] should be complex64");
MS_LOG(EXCEPTION) << "CImag's input[0] should be complex64";
}
};
@@ -266,7 +266,7 @@ class CRealOp : public ElemwiseOp {
protected:
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
throw GKException("CReal's input[0] should be complex64");
MS_LOG(EXCEPTION) << "CReal's input[0] should be complex64";
}
};
@@ -229,9 +229,7 @@ class TransformOp {
perm = perm_map[{format_b_, format_a_}];
}
if (perm.empty()) {
std::ostringstream oss;
oss << "unsupported format: " << format_a_ << " to " << format_b_;
throw graphkernel::GKException(oss.str());
MS_LOG(EXCEPTION) << "unsupported format: " << format_a_ << " to " << format_b_;
}
auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans");
op->SetAttr("perm", MakeValue(perm));
@@ -438,23 +436,19 @@ bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
bool changed = false;
for (auto node : todos) {
if (!AnfAlgo::IsGraphKernel(node)) continue;
try {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
if (Process(litegraph)) {
changed = true;
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
}
} catch (const graphkernel::GKException &e) {
MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph";
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
if (Process(litegraph)) {
changed = true;
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
}
}
return changed;