GitOrigin-RevId: 92307dd2ca
tags/v1.3.0
| @@ -73,7 +73,7 @@ PyTypeObject PyOpType(name); | |||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| template<typename T, typename SFINAE=void> | |||||
| template <typename T, typename SFINAE = void> | |||||
| struct pyobj_convert_generic { | struct pyobj_convert_generic { | ||||
| static T from(PyObject* obj) { | static T from(PyObject* obj) { | ||||
| // TODO: remove this guard which is used for pybind11 implicit conversion | // TODO: remove this guard which is used for pybind11 implicit conversion | ||||
| @@ -87,7 +87,12 @@ struct pyobj_convert_generic { | |||||
| } | } | ||||
| }; | }; | ||||
| template<typename T> | |||||
| template <typename T> | |||||
| struct EnumTrait { | |||||
| static constexpr bool is_bit_combined = false; | |||||
| }; | |||||
| template <typename T> | |||||
| PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | ||||
| PyObject* obj = type->tp_alloc(type, 0); | PyObject* obj = type->tp_alloc(type, 0); | ||||
| T* self = reinterpret_cast<T*>(obj); | T* self = reinterpret_cast<T*>(obj); | ||||
| @@ -203,9 +208,10 @@ struct EnumWrapper { | |||||
| } | } | ||||
| }; | }; | ||||
| template<typename T> | |||||
| template <typename T> | |||||
| struct pyobj_convert_generic<T, | struct pyobj_convert_generic<T, | ||||
| std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { | |||||
| std::enable_if_t<std::is_enum_v<std::decay_t<T>> && | |||||
| !EnumTrait<T>::is_bit_combined>> { | |||||
| using Wrapper = EnumWrapper<T>; | using Wrapper = EnumWrapper<T>; | ||||
| static T from(PyObject* obj) { | static T from(PyObject* obj) { | ||||
| if (PyObject_TypeCheck(obj, &Wrapper::type)) { | if (PyObject_TypeCheck(obj, &Wrapper::type)) { | ||||
| @@ -223,6 +229,115 @@ struct pyobj_convert_generic<T, | |||||
| } | } | ||||
| }; | }; | ||||
| template<typename T> | |||||
| struct BitCombinedEnumWrapper { | |||||
| static_assert(std::is_enum_v<T>); | |||||
| PyObject_HEAD | |||||
| T value; | |||||
| static const char* name; | |||||
| static PyTypeObject type; | |||||
| static std::unordered_map<T, std::string> type2str; | |||||
| static std::unordered_map<std::string, T> str2type; | |||||
| static PyNumberMethods number_methods; | |||||
| BitCombinedEnumWrapper() = default; | |||||
| BitCombinedEnumWrapper(T v): value(v) {} | |||||
| BitCombinedEnumWrapper(std::string&& str) | |||||
| : BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {} | |||||
| std::string to_string() const { | |||||
| if (static_cast<uint32_t>(value) == 0) { | |||||
| return "None"; | |||||
| } else { | |||||
| auto ret = std::string(); | |||||
| bool first = true; | |||||
| for (uint32_t i = 0; i < 32; i++) { | |||||
| uint32_t value_int = static_cast<uint32_t>(value); | |||||
| auto it = type2str.find(static_cast<T>((1 << i) & value_int)); | |||||
| if (it != type2str.end()) { | |||||
| if (!first) { | |||||
| ret += " + "; | |||||
| } else { | |||||
| first = false; | |||||
| } | |||||
| ret += (std::string(name) + "." + it->second); | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { | |||||
| PyObject* obj = type->tp_alloc(type, 0); | |||||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); | |||||
| return obj; | |||||
| } | |||||
| static int py_init(PyObject* self, PyObject* args, PyObject*) { | |||||
| int input = 1; | |||||
| if (PyArg_ParseTuple(args, "|i", &input)){ | |||||
| reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = | |||||
| static_cast<T>(input); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| static PyObject* py_repr(PyObject* self) { | |||||
| return pyobj_convert_generic<std::string>::to( | |||||
| reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string()); | |||||
| } | |||||
| static PyObject* py_or(PyObject* self, PyObject* other) { | |||||
| if(!(self->ob_type == other->ob_type)){ | |||||
| return PyErr_Format( | |||||
| PyExc_RuntimeError, | |||||
| "Operand in or operator must be the same type."); | |||||
| } | |||||
| PyObject* obj = type.tp_alloc(&type, 0); | |||||
| T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||||
| rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>( | |||||
| static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs)); | |||||
| return obj; | |||||
| } | |||||
| static PyObject* py_and(PyObject* self, PyObject* other) { | |||||
| if (!(self->ob_type == other->ob_type)) { | |||||
| return PyErr_Format( | |||||
| PyExc_RuntimeError, | |||||
| "Operand in and operator must be the same type."); | |||||
| } | |||||
| PyObject* obj = type.tp_alloc(&type, 0); | |||||
| T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||||
| rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>( | |||||
| static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs)); | |||||
| return obj; | |||||
| } | |||||
| static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { | |||||
| T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||||
| rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||||
| if (op == Py_EQ || op == Py_NE) { | |||||
| RETURN_RICHCOMPARE(lhs, rhs, op); | |||||
| } | |||||
| Py_RETURN_NOTIMPLEMENTED; | |||||
| } | |||||
| }; | |||||
| template <typename T> | |||||
| struct pyobj_convert_generic<T, | |||||
| std::enable_if_t<std::is_enum_v<std::decay_t<T>> && | |||||
| EnumTrait<T>::is_bit_combined>> { | |||||
| using Wrapper = BitCombinedEnumWrapper<T>; | |||||
| static T from(PyObject* obj) { | |||||
| if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||||
| return reinterpret_cast<Wrapper*>(obj)->value; | |||||
| } | |||||
| // try as string | |||||
| // TODO: type checkcd | |||||
| return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value; | |||||
| } | |||||
| static PyObject* to(T t) { | |||||
| PyTypeObject* pytype = &Wrapper::type; | |||||
| PyObject* obj = pytype->tp_alloc(pytype, 0); | |||||
| reinterpret_cast<Wrapper*>(obj)->value = t; | |||||
| return obj; | |||||
| } | |||||
| }; | |||||
| void _init_py_op_def(py::module m) { | void _init_py_op_def(py::module m) { | ||||
| using py_op = PyOp(OpDef); | using py_op = PyOp(OpDef); | ||||
| auto& py_type = PyOpType(OpDef); | auto& py_type = PyOpType(OpDef); | ||||
| @@ -408,61 +408,58 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||||
| os << ";\n\n"; | os << ";\n\n"; | ||||
| } | } | ||||
| static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||||
| auto className = op.getCppClassName(); | |||||
| static std::string gen_op_def_python_c_extension_enum( | |||||
| raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||||
| llvm::StringRef className) { | |||||
| std::string body; | std::string body; | ||||
| // generate PyType for enum class member | |||||
| for (auto&& i : op.getMgbAttributes()) { | |||||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
| unsigned int enumID; | |||||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
| auto&& aliasBase = alias->getAliasBase(); | |||||
| enumID = | |||||
| llvm::cast<MgbEnumAttr>(aliasBase) | |||||
| .getBaseRecord()->getID(); | |||||
| } else { | |||||
| enumID = attr->getBaseRecord()->getID(); | |||||
| } | |||||
| auto&& enumAlias = ctx.enumAlias; | |||||
| auto&& iter = enumAlias.find(enumID); | |||||
| auto enumName = attr->getEnumName(); | |||||
| body += "{\n"; | |||||
| body += formatv( | |||||
| "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName | |||||
| ); | |||||
| if (iter == enumAlias.end()) { | |||||
| os << formatv( | |||||
| "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||||
| className, enumName); | |||||
| os << formatv( | |||||
| "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", | |||||
| className, enumName); | |||||
| std::vector<std::string> pairStr; | |||||
| for (auto&& i: attr->getEnumMembers()) { | |||||
| pairStr.push_back(formatv( | |||||
| "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
| className, enumName, i)); | |||||
| } | |||||
| os << formatv(R"( | |||||
| unsigned int enumID; | |||||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
| auto&& aliasBase = alias->getAliasBase(); | |||||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||||
| } else { | |||||
| enumID = attr->getBaseRecord()->getID(); | |||||
| } | |||||
| auto&& enumAlias = ctx.enumAlias; | |||||
| auto&& iter = enumAlias.find(enumID); | |||||
| auto enumName = attr->getEnumName(); | |||||
| body += "{\n"; | |||||
| body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, | |||||
| enumName); | |||||
| if (iter == enumAlias.end()) { | |||||
| os << formatv( | |||||
| "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||||
| className, enumName); | |||||
| os << formatv( | |||||
| "template<> const char* EnumWrapper<{0}::{1}>::name = " | |||||
| "\"{0}.{1}\";\n", | |||||
| className, enumName); | |||||
| std::vector<std::string> pairStr; | |||||
| for (auto&& i : attr->getEnumMembers()) { | |||||
| pairStr.push_back( | |||||
| formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
| className, enumName, i)); | |||||
| } | |||||
| os << formatv(R"( | |||||
| template<> std::unordered_map<std::string, {0}::{1}> | template<> std::unordered_map<std::string, {0}::{1}> | ||||
| EnumWrapper<{0}::{1}>::str2type = {{ | EnumWrapper<{0}::{1}>::str2type = {{ | ||||
| {2} | {2} | ||||
| }; | }; | ||||
| )", className, enumName, llvm::join(pairStr, ", ")); | |||||
| pairStr.clear(); | |||||
| for (auto&& i: attr->getEnumMembers()) { | |||||
| pairStr.push_back(formatv( | |||||
| "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
| className, enumName, i)); | |||||
| } | |||||
| os << formatv(R"( | |||||
| )", | |||||
| className, enumName, llvm::join(pairStr, ", ")); | |||||
| pairStr.clear(); | |||||
| for (auto&& i : attr->getEnumMembers()) { | |||||
| pairStr.push_back( | |||||
| formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
| className, enumName, i)); | |||||
| } | |||||
| os << formatv(R"( | |||||
| template<> std::unordered_map<{0}::{1}, std::string> | template<> std::unordered_map<{0}::{1}, std::string> | ||||
| EnumWrapper<{0}::{1}>::type2str = {{ | EnumWrapper<{0}::{1}>::type2str = {{ | ||||
| {2} | {2} | ||||
| }; | }; | ||||
| )", className, enumName, llvm::join(pairStr, ", ")); | |||||
| body += formatv(R"( | |||||
| )", | |||||
| className, enumName, llvm::join(pairStr, ", ")); | |||||
| body += formatv(R"( | |||||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | ||||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | ||||
| e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | ||||
| @@ -472,22 +469,140 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||||
| e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | ||||
| e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | ||||
| mgb_assert(PyType_Ready(&e_type) >= 0); | mgb_assert(PyType_Ready(&e_type) >= 0); | ||||
| )", className, enumName); | |||||
| for (auto&& i: attr->getEnumMembers()) { | |||||
| body += formatv(R"({{ | |||||
| )", | |||||
| className, enumName); | |||||
| for (auto&& i : attr->getEnumMembers()) { | |||||
| body += formatv(R"({{ | |||||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | PyObject* inst = e_type.tp_alloc(&e_type, 0); | ||||
| reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | ||||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | ||||
| })", className, enumName, i); | |||||
| } | |||||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
| } | |||||
| body += formatv(R"( | |||||
| })", | |||||
| className, enumName, i); | |||||
| } | |||||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
| } | |||||
| body += formatv(R"( | |||||
| PyType_Modified(&e_type); | |||||
| mgb_assert(PyDict_SetItemString( | |||||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||||
| )", | |||||
| enumName); | |||||
| body += "}\n"; | |||||
| return body; | |||||
| } | |||||
| static std::string gen_op_def_python_c_extension_bit_combined_enum( | |||||
| raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||||
| llvm::StringRef className) { | |||||
| std::string body; | |||||
| unsigned int enumID; | |||||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||||
| auto&& aliasBase = alias->getAliasBase(); | |||||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||||
| } else { | |||||
| enumID = attr->getBaseRecord()->getID(); | |||||
| } | |||||
| auto&& enumAlias = ctx.enumAlias; | |||||
| auto&& iter = enumAlias.find(enumID); | |||||
| auto enumName = attr->getEnumName(); | |||||
| body += "{\n"; | |||||
| body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", | |||||
| className, enumName); | |||||
| if (iter == enumAlias.end()) { | |||||
| os << formatv( | |||||
| "template<> PyTypeObject " | |||||
| "BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", | |||||
| className, enumName); | |||||
| os << formatv( | |||||
| "template<> PyNumberMethods " | |||||
| "BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", | |||||
| className, enumName); | |||||
| os << formatv( | |||||
| "template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " | |||||
| "= \"{0}.{1}\";\n", | |||||
| className, enumName); | |||||
| os << formatv( | |||||
| "template<> struct EnumTrait<{0}::{1}> {{ static constexpr " | |||||
| "bool is_bit_combined = true;};\n", | |||||
| className, enumName); | |||||
| std::vector<std::string> pairStr; | |||||
| for (auto&& i : attr->getEnumMembers()) { | |||||
| pairStr.push_back( | |||||
| formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||||
| className, enumName, i)); | |||||
| } | |||||
| os << formatv(R"( | |||||
| template<> std::unordered_map<std::string, {0}::{1}> | |||||
| BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ | |||||
| {2} | |||||
| }; | |||||
| )", | |||||
| className, enumName, llvm::join(pairStr, ", ")); | |||||
| pairStr.clear(); | |||||
| for (auto&& i : attr->getEnumMembers()) { | |||||
| pairStr.push_back( | |||||
| formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||||
| className, enumName, i)); | |||||
| } | |||||
| os << formatv(R"( | |||||
| template<> std::unordered_map<{0}::{1}, std::string> | |||||
| BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ | |||||
| {2} | |||||
| }; | |||||
| )", | |||||
| className, enumName, llvm::join(pairStr, ", ")); | |||||
| body += formatv(R"( | |||||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||||
| e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); | |||||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
| e_type.tp_doc = "{0}.{1}"; | |||||
| e_type.tp_base = &PyBaseObject_Type; | |||||
| e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; | |||||
| e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; | |||||
| e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; | |||||
| e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; | |||||
| auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; | |||||
| number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; | |||||
| number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; | |||||
| e_type.tp_as_number = &number_method; | |||||
| mgb_assert(PyType_Ready(&e_type) >= 0); | |||||
| )", | |||||
| className, enumName); | |||||
| for (auto&& i : attr->getEnumMembers()) { | |||||
| body += formatv(R"({{ | |||||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||||
| reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||||
| })", | |||||
| className, enumName, i); | |||||
| } | |||||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||||
| } | |||||
| body += formatv(R"( | |||||
| PyType_Modified(&e_type); | PyType_Modified(&e_type); | ||||
| mgb_assert(PyDict_SetItemString( | mgb_assert(PyDict_SetItemString( | ||||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | ||||
| )", enumName); | |||||
| body += "}\n"; | |||||
| )", | |||||
| enumName); | |||||
| body += "}\n"; | |||||
| return body; | |||||
| } | |||||
| static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||||
| auto className = op.getCppClassName(); | |||||
| std::string body; | |||||
| // generate PyType for enum class member | |||||
| for (auto&& i : op.getMgbAttributes()) { | |||||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||||
| if (attr->getEnumCombinedFlag()) { | |||||
| body += gen_op_def_python_c_extension_bit_combined_enum( | |||||
| os, ctx, attr, className); | |||||
| } else { | |||||
| body += gen_op_def_python_c_extension_enum(os, ctx, attr, | |||||
| className); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -141,15 +141,13 @@ R"__usage__( | |||||
| )__usage__" | )__usage__" | ||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| R"__usage__( | R"__usage__( | ||||
| --fast-run | |||||
| This param will be deperated later, please replace with param --full-profile. | |||||
| --full-profile | |||||
| Enable full-profile mode. Operators with multiple algorithms would be profiled | |||||
| --full-run | |||||
| Enable full-run mode. Operators with multiple algorithms would be profiled | |||||
| on the real device with actual input shapes, all algorithms will be profiled | on the real device with actual input shapes, all algorithms will be profiled | ||||
| include naive algorithms. | include naive algorithms. | ||||
| See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | ||||
| --fast-profile | |||||
| Enable fast-profile mode. Operators with multiple algorithms would be profiled | |||||
| --fast-run | |||||
| Enable fast-run mode. Operators with multiple algorithms would be profiled | |||||
| on the real device with actual input shapes, this mode will only profile the | on the real device with actual input shapes, this mode will only profile the | ||||
| well optimized algorithms to get the profile result fast. | well optimized algorithms to get the profile result fast. | ||||
| See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | ||||
| @@ -519,8 +517,8 @@ struct Args { | |||||
| bool disable_assert_throw = false; | bool disable_assert_throw = false; | ||||
| bool share_param_mem = false; | bool share_param_mem = false; | ||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| bool use_full_profile = false; | |||||
| bool use_fast_profile = false; | |||||
| bool use_full_run = false; | |||||
| bool use_fast_run = false; | |||||
| #endif | #endif | ||||
| bool reproducible = false; | bool reproducible = false; | ||||
| std::string fast_run_cache_path; | std::string fast_run_cache_path; | ||||
| @@ -704,13 +702,13 @@ void run_test_st(Args &env) { | |||||
| using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | ||||
| S strategy = S::HEURISTIC; | S strategy = S::HEURISTIC; | ||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| if (env.use_full_profile) { | |||||
| if (env.use_full_run) { | |||||
| if (env.reproducible) { | if (env.reproducible) { | ||||
| strategy = S::PROFILE | S::REPRODUCIBLE; | strategy = S::PROFILE | S::REPRODUCIBLE; | ||||
| } else { | } else { | ||||
| strategy = S::PROFILE; | strategy = S::PROFILE; | ||||
| } | } | ||||
| } else if (env.use_fast_profile) { | |||||
| } else if (env.use_fast_run) { | |||||
| strategy = S::PROFILE | S::OPTMIZED; | strategy = S::PROFILE | S::OPTMIZED; | ||||
| } else if (env.reproducible) { | } else if (env.reproducible) { | ||||
| strategy = S::HEURISTIC | S::REPRODUCIBLE; | strategy = S::HEURISTIC | S::REPRODUCIBLE; | ||||
| @@ -740,12 +738,12 @@ void run_test_st(Args &env) { | |||||
| std::make_shared<InFilePersistentCache>(buf.get(), flen)); | std::make_shared<InFilePersistentCache>(buf.get(), flen)); | ||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| } else { | } else { | ||||
| mgb_assert(env.use_full_profile || env.use_fast_profile, | |||||
| "fast-run or fast-profile should be enabled"); | |||||
| mgb_assert(env.use_full_run || env.use_fast_run, | |||||
| "fast-run or fast-run should be enabled"); | |||||
| PersistentCache::set_impl( | PersistentCache::set_impl( | ||||
| std::make_shared<InFilePersistentCache>()); | std::make_shared<InFilePersistentCache>()); | ||||
| } | } | ||||
| if (!env.use_full_profile && !env.use_fast_profile) | |||||
| if (!env.use_full_run && !env.use_fast_run) | |||||
| #endif | #endif | ||||
| mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | ||||
| } | } | ||||
| @@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) { | |||||
| } | } | ||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| if (!strcmp(argv[i], "--fast-run")) { | if (!strcmp(argv[i], "--fast-run")) { | ||||
| mgb_log_warn( | |||||
| "--fast-run param will be deperated later, please replace " | |||||
| "with --full-profile or --fast-profile."); | |||||
| ret.use_full_profile = true; | |||||
| ret.use_fast_run = true; | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (!strcmp(argv[i], "--full-profile")) { | |||||
| ret.use_full_profile = true; | |||||
| continue; | |||||
| } | |||||
| if (!strcmp(argv[i], "--fast-profile")) { | |||||
| ret.use_fast_profile = true; | |||||
| if (!strcmp(argv[i], "--full-run")) { | |||||
| ret.use_full_run = true; | |||||
| continue; | continue; | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -12,7 +12,6 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #include "megbrain/opr/param_defs.h" | |||||
| #include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
| #include <memory> | #include <memory> | ||||
| @@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { | |||||
| } // namespace mgb | } // namespace mgb | ||||
| namespace megdnn { | |||||
| namespace param { | |||||
| MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||||
| } | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include "megbrain/utils/hashable.h" | #include "megbrain/utils/hashable.h" | ||||
| #include "megbrain/utils/thin/hash_table.h" | #include "megbrain/utils/thin/hash_table.h" | ||||
| #include "megbrain/utils/small_vector.h" | #include "megbrain/utils/small_vector.h" | ||||
| #include "megbrain/opr/param_defs.h" | |||||
| #include <type_traits> | #include <type_traits> | ||||
| @@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \ | |||||
| } // namespace cg | } // namespace cg | ||||
| } // namespace mgb | } // namespace mgb | ||||
| namespace megdnn { | |||||
| namespace param { | |||||
| MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||||
| } | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| //! Test whether the algo attribute of a algo match the require | |||||
| //! algo_strategy | |||||
| static bool algo_attribute_match_strategy(AlgoAttribute attribute, | |||||
| ExecutionStrategy selected_strategy) { | |||||
| bool ret = true; | |||||
| if (selected_strategy & ExecutionStrategy::OPTMIZED) { | |||||
| ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute)); | |||||
| } else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) { | |||||
| ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -285,8 +298,8 @@ namespace opr { | |||||
| template <typename Opr> | template <typename Opr> | ||||
| void AlgoChooser<Opr>::profile(ExeContext& ctx, | void AlgoChooser<Opr>::profile(ExeContext& ctx, | ||||
| ExecutionStrategy select_strategy) { | |||||
| if (ctx.get_profile_result_from_cache(select_strategy).valid()) | |||||
| ExecutionStrategy selected_strategy) { | |||||
| if (ctx.get_profile_result_from_cache(selected_strategy).valid()) | |||||
| return; | return; | ||||
| AlgoChooserProfileCache::Result prof_rst; | AlgoChooserProfileCache::Result prof_rst; | ||||
| @@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
| algo.name.c_str(), str_on_inp_shape.c_str()); | algo.name.c_str(), str_on_inp_shape.c_str()); | ||||
| ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
| policy.algo = algo.desc; | policy.algo = algo.desc; | ||||
| ctx.construct_execution_policy(select_strategy, policy); | |||||
| if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) | |||||
| ctx.construct_execution_policy(selected_strategy, policy); | |||||
| if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { | |||||
| continue; | continue; | ||||
| } | |||||
| auto algo_attribute = ctx.megdnn_opr() | |||||
| ->get_algorithm_from_desc(policy.algo) | |||||
| ->attribute(); | |||||
| if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) { | |||||
| mgb_log_debug( | |||||
| "skip algo %s, which is not match the profile strategy.", | |||||
| algo.name.c_str()); | |||||
| continue; | |||||
| } | |||||
| timer.reset(); | timer.reset(); | ||||
| MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } | MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } | ||||
| @@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
| template <typename Opr> | template <typename Opr> | ||||
| typename AlgoChooser<Opr>::ImplExecutionPolicy | typename AlgoChooser<Opr>::ImplExecutionPolicy | ||||
| AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | ||||
| ExecutionStrategy select_strategy, | |||||
| ExecutionStrategy selected_strategy, | |||||
| bool enable_update) { | bool enable_update) { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) | ||||
| if (ctx.owner_graph()->options().no_profiling_on_shape_change) { | if (ctx.owner_graph()->options().no_profiling_on_shape_change) { | ||||
| @@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||||
| to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | ||||
| _item.param, ctx.mgb_opr(), ctx.comp_node(), | _item.param, ctx.mgb_opr(), ctx.comp_node(), | ||||
| ctx.execution_policy(), ctx.allow_weight_preprocess()); | ctx.execution_policy(), ctx.allow_weight_preprocess()); | ||||
| AlgoChooser<_Opr>::profile(sub_ctx, select_strategy); | |||||
| AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); | |||||
| }); | }); | ||||
| } | } | ||||
| typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | ||||
| ctx.construct_execution_policy(select_strategy, policy); | |||||
| ctx.construct_execution_policy(selected_strategy, policy); | |||||
| return policy; | return policy; | ||||
| MIDOUT_E | MIDOUT_E | ||||
| } | } | ||||
| @@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | |||||
| if (!policy.algo.valid()) | if (!policy.algo.valid()) | ||||
| policy = ctx.choose_by_heuristic(opr_strategy); | policy = ctx.choose_by_heuristic(opr_strategy); | ||||
| return policy; | return policy; | ||||
| } else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||||
| } else if (!static_cast<int>(opr_strategy) || | |||||
| (opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||||
| return ctx.choose_by_heuristic(opr_strategy); | return ctx.choose_by_heuristic(opr_strategy); | ||||
| } | } | ||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| @@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | |||||
| } | } | ||||
| #endif | #endif | ||||
| else { | else { | ||||
| mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); | |||||
| mgb_throw(GraphError, "bad ExecutionPolicy strategy"); | |||||
| } | } | ||||
| } | } | ||||
| @@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext( | |||||
| template <typename Opr> | template <typename Opr> | ||||
| typename AlgoChooser<Opr>::ImplAlgo | typename AlgoChooser<Opr>::ImplAlgo | ||||
| AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | ||||
| ExecutionStrategy select_strategy) const { | |||||
| ExecutionStrategy selected_strategy) const { | |||||
| MIDOUT_B(Opr, | MIDOUT_B(Opr, | ||||
| midout_iv(MGB_HASH_STR( | midout_iv(MGB_HASH_STR( | ||||
| "AlgoChooser::ExeContext::get_profile_result_from_cache"))) | "AlgoChooser::ExeContext::get_profile_result_from_cache"))) | ||||
| @@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
| if (prof.empty()) | if (prof.empty()) | ||||
| return {}; | return {}; | ||||
| for (auto&& i : prof) { | for (auto&& i : prof) { | ||||
| if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||||
| if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||||
| static_cast<AlgoAttribute>(i.attribute) & | static_cast<AlgoAttribute>(i.attribute) & | ||||
| AlgoAttribute::REPRODUCIBLE) { | AlgoAttribute::REPRODUCIBLE) { | ||||
| auto iter = algo_map.find(i.algo); | auto iter = algo_map.find(i.algo); | ||||
| @@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
| template <typename Opr> | template <typename Opr> | ||||
| typename AlgoChooser<Opr>::ImplExecutionPolicy | typename AlgoChooser<Opr>::ImplExecutionPolicy | ||||
| AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | ||||
| ExecutionStrategy select_strategy) const { | |||||
| ExecutionStrategy selected_strategy) const { | |||||
| if (m_execution_policy.workspace_limit != | if (m_execution_policy.workspace_limit != | ||||
| std::numeric_limits<decltype( | std::numeric_limits<decltype( | ||||
| m_execution_policy.workspace_limit)>::max()) { | m_execution_policy.workspace_limit)>::max()) { | ||||
| @@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
| "workspace_limit should not be setted if choose algo by " | "workspace_limit should not be setted if choose algo by " | ||||
| "heuristic"); | "heuristic"); | ||||
| } | } | ||||
| bool reproducible = static_cast<bool>(select_strategy & | |||||
| bool reproducible = static_cast<bool>(selected_strategy & | |||||
| ExecutionStrategy::REPRODUCIBLE); | ExecutionStrategy::REPRODUCIBLE); | ||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| @@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
| _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | ||||
| m_allow_weight_preprocess); | m_allow_weight_preprocess); | ||||
| policy.sub_policy.push_back( | policy.sub_policy.push_back( | ||||
| sub_ctx.choose_by_heuristic(select_strategy)); | |||||
| sub_ctx.choose_by_heuristic(selected_strategy)); | |||||
| }); | }); | ||||
| return policy; | return policy; | ||||
| @@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const { | |||||
| template <typename Opr> | template <typename Opr> | ||||
| void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | ||||
| ExecutionStrategy select_strategy, | |||||
| ExecutionStrategy selected_strategy, | |||||
| typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | ||||
| bool retrive_from_cache) const { | bool retrive_from_cache) const { | ||||
| bool reproducible = static_cast<bool>(select_strategy & | |||||
| bool reproducible = static_cast<bool>(selected_strategy & | |||||
| ExecutionStrategy::REPRODUCIBLE); | ExecutionStrategy::REPRODUCIBLE); | ||||
| if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
| if (retrive_from_cache) { | if (retrive_from_cache) { | ||||
| policy.algo = | policy.algo = | ||||
| get_profile_result_from_cache(select_strategy).desc; | |||||
| get_profile_result_from_cache(selected_strategy).desc; | |||||
| } else { | } else { | ||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| @@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
| _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | ||||
| m_allow_weight_preprocess); | m_allow_weight_preprocess); | ||||
| policy.sub_policy.push_back({}); | policy.sub_policy.push_back({}); | ||||
| sub_ctx.construct_execution_policy(select_strategy, | |||||
| sub_ctx.construct_execution_policy(selected_strategy, | |||||
| policy.sub_policy.back(), | policy.sub_policy.back(), | ||||
| retrive_from_cache); | retrive_from_cache); | ||||
| }); | }); | ||||
| @@ -110,7 +110,7 @@ public: | |||||
| const FixedTensorLayouts& layouts() const { return m_layouts; } | const FixedTensorLayouts& layouts() const { return m_layouts; } | ||||
| ImplExecutionPolicy choose_by_heuristic( | ImplExecutionPolicy choose_by_heuristic( | ||||
| ExecutionStrategy select_strategy) const; | |||||
| ExecutionStrategy selected_strategy) const; | |||||
| //! get all candidate algos, and the one choose_by_heuristic() is | //! get all candidate algos, and the one choose_by_heuristic() is | ||||
| //! put first | //! put first | ||||
| @@ -134,17 +134,17 @@ public: | |||||
| //! get all profile algorithm from cache, return invalid if not exists | //! get all profile algorithm from cache, return invalid if not exists | ||||
| ImplAlgo get_profile_result_from_cache( | ImplAlgo get_profile_result_from_cache( | ||||
| ExecutionStrategy select_strategy) const; | |||||
| ExecutionStrategy selected_strategy) const; | |||||
| /** | /** | ||||
| * \brief construct execution policy from cache or heuristic. | * \brief construct execution policy from cache or heuristic. | ||||
| * | * | ||||
| * \param select_strategy select algo which matched this strategy | |||||
| * \param selected_strategy select algo which matched this strategy | |||||
| * \param policy execution policy | * \param policy execution policy | ||||
| * \param retrive_from_cache retrive algo from cache if set True, get | * \param retrive_from_cache retrive algo from cache if set True, get | ||||
| * from heuristic otherwise. | * from heuristic otherwise. | ||||
| */ | */ | ||||
| void construct_execution_policy(ExecutionStrategy select_strategy, | |||||
| void construct_execution_policy(ExecutionStrategy selected_strategy, | |||||
| ImplExecutionPolicy& policy, | ImplExecutionPolicy& policy, | ||||
| bool retrive_from_cache = true) const; | bool retrive_from_cache = true) const; | ||||
| @@ -161,10 +161,10 @@ private: | |||||
| //! profile and save to cache | //! profile and save to cache | ||||
| static void profile(ExeContext& ctx, ExecutionStrategy select_strategy); | |||||
| static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); | |||||
| static ImplExecutionPolicy choose_by_profile( | static ImplExecutionPolicy choose_by_profile( | ||||
| ExeContext& ctx, ExecutionStrategy select_strategy, | |||||
| ExeContext& ctx, ExecutionStrategy selected_strategy, | |||||
| bool enable_update = true); | bool enable_update = true); | ||||
| public: | public: | ||||