GitOrigin-RevId: f47ceae726
tags/v1.3.0
| @@ -11,6 +11,11 @@ import io | |||||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | from gen_param_defs import member_defs, ParamDef, IndentWriterBase | ||||
| # FIXME: move supportToString flag definition into the param def source file | |||||
| ENUM_TO_STRING_SPECIAL_RULES = [ | |||||
| ("Elemwise", "Mode"), | |||||
| ("ElemwiseMultiType", "Mode") | |||||
| ] | |||||
| class ConverterWriter(IndentWriterBase): | class ConverterWriter(IndentWriterBase): | ||||
| _skip_current_param = False | _skip_current_param = False | ||||
| @@ -86,7 +91,10 @@ class ConverterWriter(IndentWriterBase): | |||||
| def format(v): | def format(v): | ||||
| return '\"{}\"'.format(str(v)) | return '\"{}\"'.format(str(v)) | ||||
| enum_def += ','.join(format(i) for i in e.members) | enum_def += ','.join(format(i) for i in e.members) | ||||
| enum_def += "]>" | |||||
| enum_def += "]" | |||||
| if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | |||||
| enum_def += ", 1" # whether generate ToStringTrait | |||||
| enum_def += ">" | |||||
| self._write("def {} : {};".format(td_class, enum_def)) | self._write("def {} : {};".format(td_class, enum_def)) | ||||
| if self._skip_current_param: | if self._skip_current_param: | ||||
| @@ -12,6 +12,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
| #include "megbrain/imperative/utils/to_string.h" | |||||
| #include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
| #include "megbrain/opr/param_defs.h" | #include "megbrain/opr/param_defs.h" | ||||
| @@ -179,6 +179,34 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||||
| ); | ); | ||||
| } | } | ||||
| static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { | |||||
| for (auto &&i : op.getMgbAttributes()) { | |||||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
| if (attr->supportToString()) { | |||||
| std::vector<std::string> case_body; | |||||
| std::string ename = formatv("{0}::{1}", | |||||
| op.getCppClassName(), attr->getEnumName()); | |||||
| llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||||
| case_body.push_back(formatv( | |||||
| "case {0}::{1}: return \"{1}\";", ename, v)); | |||||
| }); | |||||
| os << formatv(R"( | |||||
| template <> | |||||
| struct ToStringTrait<{0}> { | |||||
| std::string operator()({0} e) const { | |||||
| switch (e) { | |||||
| {1} | |||||
| default: | |||||
| return "{0}::Unknown"; | |||||
| } | |||||
| } | |||||
| }; | |||||
| )", ename, llvm::join(case_body, "\n")); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | ||||
| auto&& className = op.getCppClassName(); | auto&& className = op.getCppClassName(); | ||||
| os << formatv( | os << formatv( | ||||
| @@ -241,7 +269,13 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||||
| os << formatv( | os << formatv( | ||||
| "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | ||||
| ); | ); | ||||
| os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx); | |||||
| os << formatv( | |||||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||||
| " static_cast<void>(op_);\n", | |||||
| className | |||||
| ); | |||||
| ctx.withSelf("op_"); | |||||
| os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||||
| os << "}\n"; | os << "}\n"; | ||||
| os << "} // anonymous namespace\n"; | os << "} // anonymous namespace\n"; | ||||
| @@ -577,6 +611,7 @@ static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||||
| static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | ||||
| for_each_operator(os, keeper, gen_op_def_c_header_single); | for_each_operator(os, keeper, gen_op_def_c_header_single); | ||||
| for_each_operator(os, keeper, gen_to_string_trait_for_enum); | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -74,6 +74,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||||
| std::vector<StringRef> getEnumMembers() const { | std::vector<StringRef> getEnumMembers() const { | ||||
| return getBaseRecord()->getValueAsListOfStrings("enumMembers"); | return getBaseRecord()->getValueAsListOfStrings("enumMembers"); | ||||
| } | } | ||||
| bool supportToString() const { | |||||
| return getBaseRecord()->getValueAsBit("supportToString"); | |||||
| } | |||||
| }; | }; | ||||
| struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | ||||
| @@ -170,6 +173,12 @@ public: | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::string getNameFunctionTemplate() const { | |||||
| if (auto f = getDef().getValueAsOptionalString("nameFunction")) { | |||||
| return f.getValue().str(); | |||||
| } | |||||
| return formatv(" return \"{0}\";\n", getCppClassName()); | |||||
| } | |||||
| }; | }; | ||||
| struct MgbHashableOpMixin : public MgbOpBase { | struct MgbHashableOpMixin : public MgbOpBase { | ||||
| @@ -241,30 +250,6 @@ private: | |||||
| body += " return props_;\n"; | body += " return props_;\n"; | ||||
| return body; | return body; | ||||
| } | } | ||||
| std::string getModeName() const { | |||||
| std::string body = formatv( | |||||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||||
| " static_cast<void>(op_);\n", | |||||
| getCppClassName() | |||||
| ); | |||||
| for (auto&& it : getMgbAttributes()) { | |||||
| if (it.name == "mode") { | |||||
| auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr); | |||||
| body += " switch (op_.mode){\n"; | |||||
| for (auto&& enumMember: enumAttr->getEnumMembers()) { | |||||
| body += formatv( | |||||
| " case {0}::{1}::{2}:\n", | |||||
| getCppClassName(), enumAttr->getEnumName(), enumMember | |||||
| ); | |||||
| body += formatv(" return \"{0}\";\n", enumMember); | |||||
| } | |||||
| body += formatv( | |||||
| " default: return \"{0}::Unknown\";\n", getCppClassName()); | |||||
| body += " }\n"; | |||||
| } | |||||
| } | |||||
| return body; | |||||
| } | |||||
| public: | public: | ||||
| static bool classof(const Operator* op) { | static bool classof(const Operator* op) { | ||||
| return op->getDef().isSubClassOf("MgbHashableOpMixin"); | return op->getDef().isSubClassOf("MgbHashableOpMixin"); | ||||
| @@ -288,12 +273,6 @@ public: | |||||
| } | } | ||||
| return getDefaultPropsFunction(); | return getDefaultPropsFunction(); | ||||
| } | } | ||||
| std::string getNameFunctionTemplate() const { | |||||
| if (getDef().getValueAsBit("usingModeName")) { | |||||
| return getModeName(); | |||||
| } | |||||
| return formatv(" return \"{0}\";\n", getCppClassName()); | |||||
| } | |||||
| }; | }; | ||||
| } // namespace tblgen | } // namespace tblgen | ||||
| @@ -33,10 +33,11 @@ class MgbHashableAttrMixin { | |||||
| string reprFunction = "std::to_string($0)"; | string reprFunction = "std::to_string($0)"; | ||||
| } | } | ||||
| class MgbEnumAttrMixin<string namespace, string name, list<string> members> { | |||||
| class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> { | |||||
| string parentNamespace = namespace; | string parentNamespace = namespace; | ||||
| string enumName = name; | string enumName = name; | ||||
| list<string> enumMembers = members; | list<string> enumMembers = members; | ||||
| bit supportToString = toString; | |||||
| } | } | ||||
| class MgbAttrWrapper; | class MgbAttrWrapper; | ||||
| @@ -165,8 +166,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>: | |||||
| } | } | ||||
| // -- enum types | // -- enum types | ||||
| class MgbEnumAttr<string namespace, string enumName, list<string> members>: | |||||
| HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> { | |||||
| class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>: | |||||
| HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> { | |||||
| let storageType = "::mlir::IntegerAttr"; | let storageType = "::mlir::IntegerAttr"; | ||||
| let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | ||||
| let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | ||||
| @@ -242,7 +243,6 @@ class MgbPackedParamBase<string className, string accessor>: | |||||
| class MgbHashableOpMixin { | class MgbHashableOpMixin { | ||||
| string hashFunction = ?; | string hashFunction = ?; | ||||
| string cmpFunction = ?; | string cmpFunction = ?; | ||||
| bit usingModeName = 0; | |||||
| } | } | ||||
| class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | ||||
| @@ -251,6 +251,7 @@ class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits= | |||||
| dag extraArguments = (ins); | dag extraArguments = (ins); | ||||
| // TODO: remove it | // TODO: remove it | ||||
| code extraOpdefDecl = ?; | code extraOpdefDecl = ?; | ||||
| code nameFunction = ?; | |||||
| let arguments = !con( | let arguments = !con( | ||||
| !foldl(inputs, params, args, param, !con(args, param.fields)), | !foldl(inputs, params, args, param, !con(args, param.fields)), | ||||
| @@ -21,7 +21,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td" | |||||
| def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { | def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { | ||||
| let inputs = (ins Variadic<AnyType>:$input); | let inputs = (ins Variadic<AnyType>:$input); | ||||
| let results = (outs AnyType); | let results = (outs AnyType); | ||||
| let usingModeName = 1; | |||||
| let nameFunction = [{ | |||||
| return to_string($_self.mode); | |||||
| }]; | |||||
| } | } | ||||
| def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; | def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; | ||||
| @@ -248,7 +250,9 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara | |||||
| let extraArguments = (ins | let extraArguments = (ins | ||||
| MgbDTypeAttr:$dtype | MgbDTypeAttr:$dtype | ||||
| ); | ); | ||||
| let usingModeName = 1; | |||||
| let nameFunction = [{ | |||||
| return to_string($_self.mode); | |||||
| }]; | |||||
| } | } | ||||
| def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | ||||