GitOrigin-RevId: f47ceae726
tags/v1.3.0
| @@ -11,6 +11,11 @@ import io | |||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
| # FIXME: move supportToString flag definition into the param def source file | |||
| ENUM_TO_STRING_SPECIAL_RULES = [ | |||
| ("Elemwise", "Mode"), | |||
| ("ElemwiseMultiType", "Mode") | |||
| ] | |||
| class ConverterWriter(IndentWriterBase): | |||
| _skip_current_param = False | |||
| @@ -86,7 +91,10 @@ class ConverterWriter(IndentWriterBase): | |||
| def format(v): | |||
| return '\"{}\"'.format(str(v)) | |||
| enum_def += ','.join(format(i) for i in e.members) | |||
| enum_def += "]>" | |||
| enum_def += "]" | |||
| if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | |||
| enum_def += ", 1" # whether generate ToStringTrait | |||
| enum_def += ">" | |||
| self._write("def {} : {};".format(td_class, enum_def)) | |||
| if self._skip_current_param: | |||
| @@ -12,6 +12,7 @@ | |||
| #pragma once | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megbrain/imperative/utils/to_string.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| @@ -179,6 +179,34 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
| ); | |||
| } | |||
| static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| if (attr->supportToString()) { | |||
| std::vector<std::string> case_body; | |||
| std::string ename = formatv("{0}::{1}", | |||
| op.getCppClassName(), attr->getEnumName()); | |||
| llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||
| case_body.push_back(formatv( | |||
| "case {0}::{1}: return \"{1}\";", ename, v)); | |||
| }); | |||
| os << formatv(R"( | |||
| template <> | |||
| struct ToStringTrait<{0}> { | |||
| std::string operator()({0} e) const { | |||
| switch (e) { | |||
| {1} | |||
| default: | |||
| return "{0}::Unknown"; | |||
| } | |||
| } | |||
| }; | |||
| )", ename, llvm::join(case_body, "\n")); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| auto&& className = op.getCppClassName(); | |||
| os << formatv( | |||
| @@ -241,7 +269,13 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| os << formatv( | |||
| "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
| ); | |||
| os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| os << "} // anonymous namespace\n"; | |||
| @@ -577,6 +611,7 @@ static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
| static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | |||
| for_each_operator(os, keeper, gen_op_def_c_header_single); | |||
| for_each_operator(os, keeper, gen_to_string_trait_for_enum); | |||
| return false; | |||
| } | |||
| @@ -74,6 +74,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||
| std::vector<StringRef> getEnumMembers() const { | |||
| return getBaseRecord()->getValueAsListOfStrings("enumMembers"); | |||
| } | |||
| bool supportToString() const { | |||
| return getBaseRecord()->getValueAsBit("supportToString"); | |||
| } | |||
| }; | |||
| struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | |||
| @@ -170,6 +173,12 @@ public: | |||
| } | |||
| return ret; | |||
| } | |||
| std::string getNameFunctionTemplate() const { | |||
| if (auto f = getDef().getValueAsOptionalString("nameFunction")) { | |||
| return f.getValue().str(); | |||
| } | |||
| return formatv(" return \"{0}\";\n", getCppClassName()); | |||
| } | |||
| }; | |||
| struct MgbHashableOpMixin : public MgbOpBase { | |||
| @@ -241,30 +250,6 @@ private: | |||
| body += " return props_;\n"; | |||
| return body; | |||
| } | |||
| std::string getModeName() const { | |||
| std::string body = formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| getCppClassName() | |||
| ); | |||
| for (auto&& it : getMgbAttributes()) { | |||
| if (it.name == "mode") { | |||
| auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr); | |||
| body += " switch (op_.mode){\n"; | |||
| for (auto&& enumMember: enumAttr->getEnumMembers()) { | |||
| body += formatv( | |||
| " case {0}::{1}::{2}:\n", | |||
| getCppClassName(), enumAttr->getEnumName(), enumMember | |||
| ); | |||
| body += formatv(" return \"{0}\";\n", enumMember); | |||
| } | |||
| body += formatv( | |||
| " default: return \"{0}::Unknown\";\n", getCppClassName()); | |||
| body += " }\n"; | |||
| } | |||
| } | |||
| return body; | |||
| } | |||
| public: | |||
| static bool classof(const Operator* op) { | |||
| return op->getDef().isSubClassOf("MgbHashableOpMixin"); | |||
| @@ -288,12 +273,6 @@ public: | |||
| } | |||
| return getDefaultPropsFunction(); | |||
| } | |||
| std::string getNameFunctionTemplate() const { | |||
| if (getDef().getValueAsBit("usingModeName")) { | |||
| return getModeName(); | |||
| } | |||
| return formatv(" return \"{0}\";\n", getCppClassName()); | |||
| } | |||
| }; | |||
| } // namespace tblgen | |||
| @@ -33,10 +33,11 @@ class MgbHashableAttrMixin { | |||
| string reprFunction = "std::to_string($0)"; | |||
| } | |||
| class MgbEnumAttrMixin<string namespace, string name, list<string> members> { | |||
| class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> { | |||
| string parentNamespace = namespace; | |||
| string enumName = name; | |||
| list<string> enumMembers = members; | |||
| bit supportToString = toString; | |||
| } | |||
| class MgbAttrWrapper; | |||
| @@ -165,8 +166,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>: | |||
| } | |||
| // -- enum types | |||
| class MgbEnumAttr<string namespace, string enumName, list<string> members>: | |||
| HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> { | |||
| class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>: | |||
| HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> { | |||
| let storageType = "::mlir::IntegerAttr"; | |||
| let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | |||
| let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | |||
| @@ -242,7 +243,6 @@ class MgbPackedParamBase<string className, string accessor>: | |||
| class MgbHashableOpMixin { | |||
| string hashFunction = ?; | |||
| string cmpFunction = ?; | |||
| bit usingModeName = 0; | |||
| } | |||
| class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | |||
| @@ -251,6 +251,7 @@ class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits= | |||
| dag extraArguments = (ins); | |||
| // TODO: remove it | |||
| code extraOpdefDecl = ?; | |||
| code nameFunction = ?; | |||
| let arguments = !con( | |||
| !foldl(inputs, params, args, param, !con(args, param.fields)), | |||
| @@ -21,7 +21,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td" | |||
| def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { | |||
| let inputs = (ins Variadic<AnyType>:$input); | |||
| let results = (outs AnyType); | |||
| let usingModeName = 1; | |||
| let nameFunction = [{ | |||
| return to_string($_self.mode); | |||
| }]; | |||
| } | |||
| def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; | |||
| @@ -248,7 +250,9 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara | |||
| let extraArguments = (ins | |||
| MgbDTypeAttr:$dtype | |||
| ); | |||
| let usingModeName = 1; | |||
| let nameFunction = [{ | |||
| return to_string($_self.mode); | |||
| }]; | |||
| } | |||
| def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | |||