GitOrigin-RevId: 92307dd2ca
tags/v1.3.0
| @@ -73,7 +73,7 @@ PyTypeObject PyOpType(name); | |||
| } \ | |||
| } while (0) | |||
| template<typename T, typename SFINAE=void> | |||
| template <typename T, typename SFINAE = void> | |||
| struct pyobj_convert_generic { | |||
| static T from(PyObject* obj) { | |||
| // TODO: remove this guard which is used for pybind11 implicit conversion | |||
| @@ -87,7 +87,12 @@ struct pyobj_convert_generic { | |||
| } | |||
| }; | |||
| template<typename T> | |||
//! Compile-time trait describing how an enum type is exposed to Python.
//! The primary template reports that the enum is NOT a bit-combined (flag)
//! enum; a specialization may override `is_bit_combined` with `true`.
template <class Enum>
struct EnumTrait {
    static constexpr bool is_bit_combined{false};
};
| template <typename T> | |||
| PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { | |||
| PyObject* obj = type->tp_alloc(type, 0); | |||
| T* self = reinterpret_cast<T*>(obj); | |||
| @@ -203,9 +208,10 @@ struct EnumWrapper { | |||
| } | |||
| }; | |||
| template<typename T> | |||
| template <typename T> | |||
| struct pyobj_convert_generic<T, | |||
| std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { | |||
| std::enable_if_t<std::is_enum_v<std::decay_t<T>> && | |||
| !EnumTrait<T>::is_bit_combined>> { | |||
| using Wrapper = EnumWrapper<T>; | |||
| static T from(PyObject* obj) { | |||
| if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||
| @@ -223,6 +229,115 @@ struct pyobj_convert_generic<T, | |||
| } | |||
| }; | |||
| template<typename T> | |||
| struct BitCombinedEnumWrapper { | |||
| static_assert(std::is_enum_v<T>); | |||
| PyObject_HEAD | |||
| T value; | |||
| static const char* name; | |||
| static PyTypeObject type; | |||
| static std::unordered_map<T, std::string> type2str; | |||
| static std::unordered_map<std::string, T> str2type; | |||
| static PyNumberMethods number_methods; | |||
| BitCombinedEnumWrapper() = default; | |||
| BitCombinedEnumWrapper(T v): value(v) {} | |||
| BitCombinedEnumWrapper(std::string&& str) | |||
| : BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {} | |||
| std::string to_string() const { | |||
| if (static_cast<uint32_t>(value) == 0) { | |||
| return "None"; | |||
| } else { | |||
| auto ret = std::string(); | |||
| bool first = true; | |||
| for (uint32_t i = 0; i < 32; i++) { | |||
| uint32_t value_int = static_cast<uint32_t>(value); | |||
| auto it = type2str.find(static_cast<T>((1 << i) & value_int)); | |||
| if (it != type2str.end()) { | |||
| if (!first) { | |||
| ret += " + "; | |||
| } else { | |||
| first = false; | |||
| } | |||
| ret += (std::string(name) + "." + it->second); | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| } | |||
| static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { | |||
| PyObject* obj = type->tp_alloc(type, 0); | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); | |||
| return obj; | |||
| } | |||
| static int py_init(PyObject* self, PyObject* args, PyObject*) { | |||
| int input = 1; | |||
| if (PyArg_ParseTuple(args, "|i", &input)){ | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = | |||
| static_cast<T>(input); | |||
| } | |||
| return 0; | |||
| } | |||
| static PyObject* py_repr(PyObject* self) { | |||
| return pyobj_convert_generic<std::string>::to( | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string()); | |||
| } | |||
| static PyObject* py_or(PyObject* self, PyObject* other) { | |||
| if(!(self->ob_type == other->ob_type)){ | |||
| return PyErr_Format( | |||
| PyExc_RuntimeError, | |||
| "Operand in or operator must be the same type."); | |||
| } | |||
| PyObject* obj = type.tp_alloc(&type, 0); | |||
| T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||
| rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>( | |||
| static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs)); | |||
| return obj; | |||
| } | |||
| static PyObject* py_and(PyObject* self, PyObject* other) { | |||
| if (!(self->ob_type == other->ob_type)) { | |||
| return PyErr_Format( | |||
| PyExc_RuntimeError, | |||
| "Operand in and operator must be the same type."); | |||
| } | |||
| PyObject* obj = type.tp_alloc(&type, 0); | |||
| T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||
| rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>( | |||
| static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs)); | |||
| return obj; | |||
| } | |||
| static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { | |||
| T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value, | |||
| rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value; | |||
| if (op == Py_EQ || op == Py_NE) { | |||
| RETURN_RICHCOMPARE(lhs, rhs, op); | |||
| } | |||
| Py_RETURN_NOTIMPLEMENTED; | |||
| } | |||
| }; | |||
| template <typename T> | |||
| struct pyobj_convert_generic<T, | |||
| std::enable_if_t<std::is_enum_v<std::decay_t<T>> && | |||
| EnumTrait<T>::is_bit_combined>> { | |||
| using Wrapper = BitCombinedEnumWrapper<T>; | |||
| static T from(PyObject* obj) { | |||
| if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||
| return reinterpret_cast<Wrapper*>(obj)->value; | |||
| } | |||
| // try as string | |||
| // TODO: type checkcd | |||
| return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value; | |||
| } | |||
| static PyObject* to(T t) { | |||
| PyTypeObject* pytype = &Wrapper::type; | |||
| PyObject* obj = pytype->tp_alloc(pytype, 0); | |||
| reinterpret_cast<Wrapper*>(obj)->value = t; | |||
| return obj; | |||
| } | |||
| }; | |||
| void _init_py_op_def(py::module m) { | |||
| using py_op = PyOp(OpDef); | |||
| auto& py_type = PyOpType(OpDef); | |||
| @@ -408,61 +408,58 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||
| os << ";\n\n"; | |||
| } | |||
| static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
| auto className = op.getCppClassName(); | |||
| static std::string gen_op_def_python_c_extension_enum( | |||
| raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
| llvm::StringRef className) { | |||
| std::string body; | |||
| // generate PyType for enum class member | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = | |||
| llvm::cast<MgbEnumAttr>(aliasBase) | |||
| .getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| auto enumName = attr->getEnumName(); | |||
| body += "{\n"; | |||
| body += formatv( | |||
| "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName | |||
| ); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", | |||
| className, enumName); | |||
| std::vector<std::string> pairStr; | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| pairStr.push_back(formatv( | |||
| "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| auto enumName = attr->getEnumName(); | |||
| body += "{\n"; | |||
| body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, | |||
| enumName); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> const char* EnumWrapper<{0}::{1}>::name = " | |||
| "\"{0}.{1}\";\n", | |||
| className, enumName); | |||
| std::vector<std::string> pairStr; | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<std::string, {0}::{1}> | |||
| EnumWrapper<{0}::{1}>::str2type = {{ | |||
| {2} | |||
| }; | |||
| )", className, enumName, llvm::join(pairStr, ", ")); | |||
| pairStr.clear(); | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| pairStr.push_back(formatv( | |||
| "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| pairStr.clear(); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<{0}::{1}, std::string> | |||
| EnumWrapper<{0}::{1}>::type2str = {{ | |||
| {2} | |||
| }; | |||
| )", className, enumName, llvm::join(pairStr, ", ")); | |||
| body += formatv(R"( | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| body += formatv(R"( | |||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
| e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||
| @@ -472,22 +469,140 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
| e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||
| e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||
| mgb_assert(PyType_Ready(&e_type) >= 0); | |||
| )", className, enumName); | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| body += formatv(R"({{ | |||
| )", | |||
| className, enumName); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| body += formatv(R"({{ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
| })", className, enumName, i); | |||
| } | |||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
| } | |||
| body += formatv(R"( | |||
| })", | |||
| className, enumName, i); | |||
| } | |||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
| } | |||
| body += formatv(R"( | |||
| PyType_Modified(&e_type); | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", | |||
| enumName); | |||
| body += "}\n"; | |||
| return body; | |||
| } | |||
// Emit the Python C-extension glue for one bit-combined enum attribute.
//
// Text written to `os`: explicit specializations of the static members of
// BitCombinedEnumWrapper<className::enumName> (type, number_methods, name,
// str2type, type2str) and an EnumTrait specialization with
// is_bit_combined = true. These are written only the first time an enum ID
// is seen in ctx.enumAlias.
//
// Returned string: a braced statement block that (on first sight of the enum)
// fills in and readies the PyTypeObject and stores one instance per enum
// member in its tp_dict, and (always) stores the enum type in py_type.tp_dict
// under the enum's name.
| static std::string gen_op_def_python_c_extension_bit_combined_enum( | |||
| raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
| llvm::StringRef className) { | |||
| std::string body; | |||
// An alias attribute is identified by the record ID of the enum it aliases,
// otherwise by its own base record ID.
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| auto enumName = attr->getEnumName(); | |||
| body += "{\n"; | |||
| body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", | |||
| className, enumName); | |||
// Definitions and type setup are generated once per enum ID.
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "template<> PyTypeObject " | |||
| "BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> PyNumberMethods " | |||
| "BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " | |||
| "= \"{0}.{1}\";\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> struct EnumTrait<{0}::{1}> {{ static constexpr " | |||
| "bool is_bit_combined = true;};\n", | |||
| className, enumName); | |||
// Initializer pairs for str2type: normalized member name -> enum value.
| std::vector<std::string> pairStr; | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<std::string, {0}::{1}> | |||
| BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
// Initializer pairs for type2str: enum value -> normalized member name.
| pairStr.clear(); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<{0}::{1}, std::string> | |||
| BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
// Generated runtime code: populate the type object's slots (including the
// nb_or / nb_and number methods) and call PyType_Ready.
| body += formatv(R"( | |||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
| e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); | |||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| e_type.tp_doc = "{0}.{1}"; | |||
| e_type.tp_base = &PyBaseObject_Type; | |||
| e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; | |||
| e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; | |||
| e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; | |||
| e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; | |||
| auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; | |||
| number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; | |||
| number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; | |||
| e_type.tp_as_number = &number_method; | |||
| mgb_assert(PyType_Ready(&e_type) >= 0); | |||
| )", | |||
| className, enumName); | |||
// Generated runtime code: one wrapper instance per enum member, stored in the
// type's tp_dict under the member's name.
| for (auto&& i : attr->getEnumMembers()) { | |||
| body += formatv(R"({{ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
| })", | |||
| className, enumName, i); | |||
| } | |||
// Remember that this enum ID has been emitted, keyed to (className, enumName).
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
| } | |||
| body += formatv(R"( | |||
| PyType_Modified(&e_type); | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", enumName); | |||
| body += "}\n"; | |||
| )", | |||
| enumName); | |||
| body += "}\n"; | |||
| return body; | |||
| } | |||
| static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
| auto className = op.getCppClassName(); | |||
| std::string body; | |||
| // generate PyType for enum class member | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| if (attr->getEnumCombinedFlag()) { | |||
| body += gen_op_def_python_c_extension_bit_combined_enum( | |||
| os, ctx, attr, className); | |||
| } else { | |||
| body += gen_op_def_python_c_extension_enum(os, ctx, attr, | |||
| className); | |||
| } | |||
| } | |||
| } | |||
| @@ -141,15 +141,13 @@ R"__usage__( | |||
| )__usage__" | |||
| #if MGB_ENABLE_FASTRUN | |||
| R"__usage__( | |||
| --fast-run | |||
| This param will be deperated later, please replace with param --full-profile. | |||
| --full-profile | |||
| Enable full-profile mode. Operators with multiple algorithms would be profiled | |||
| --full-run | |||
| Enable full-run mode. Operators with multiple algorithms would be profiled | |||
| on the real device with actual input shapes, all algorithms will be profiled | |||
| include naive algorithms. | |||
| See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | |||
| --fast-profile | |||
| Enable fast-profile mode. Operators with multiple algorithms would be profiled | |||
| --fast-run | |||
| Enable fast-run mode. Operators with multiple algorithms would be profiled | |||
| on the real device with actual input shapes, this mode will only profile the | |||
| well optimized algorithms to get the profile result fast. | |||
| See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. | |||
| @@ -519,8 +517,8 @@ struct Args { | |||
| bool disable_assert_throw = false; | |||
| bool share_param_mem = false; | |||
| #if MGB_ENABLE_FASTRUN | |||
| bool use_full_profile = false; | |||
| bool use_fast_profile = false; | |||
| bool use_full_run = false; | |||
| bool use_fast_run = false; | |||
| #endif | |||
| bool reproducible = false; | |||
| std::string fast_run_cache_path; | |||
| @@ -704,13 +702,13 @@ void run_test_st(Args &env) { | |||
| using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
| S strategy = S::HEURISTIC; | |||
| #if MGB_ENABLE_FASTRUN | |||
| if (env.use_full_profile) { | |||
| if (env.use_full_run) { | |||
| if (env.reproducible) { | |||
| strategy = S::PROFILE | S::REPRODUCIBLE; | |||
| } else { | |||
| strategy = S::PROFILE; | |||
| } | |||
| } else if (env.use_fast_profile) { | |||
| } else if (env.use_fast_run) { | |||
| strategy = S::PROFILE | S::OPTMIZED; | |||
| } else if (env.reproducible) { | |||
| strategy = S::HEURISTIC | S::REPRODUCIBLE; | |||
| @@ -740,12 +738,12 @@ void run_test_st(Args &env) { | |||
| std::make_shared<InFilePersistentCache>(buf.get(), flen)); | |||
| #if MGB_ENABLE_FASTRUN | |||
| } else { | |||
| mgb_assert(env.use_full_profile || env.use_fast_profile, | |||
| "fast-run or fast-profile should be enabled"); | |||
| mgb_assert(env.use_full_run || env.use_fast_run, | |||
| "fast-run or fast-run should be enabled"); | |||
| PersistentCache::set_impl( | |||
| std::make_shared<InFilePersistentCache>()); | |||
| } | |||
| if (!env.use_full_profile && !env.use_fast_profile) | |||
| if (!env.use_full_run && !env.use_fast_run) | |||
| #endif | |||
| mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | |||
| } | |||
| @@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) { | |||
| } | |||
| #if MGB_ENABLE_FASTRUN | |||
| if (!strcmp(argv[i], "--fast-run")) { | |||
| mgb_log_warn( | |||
| "--fast-run param will be deperated later, please replace " | |||
| "with --full-profile or --fast-profile."); | |||
| ret.use_full_profile = true; | |||
| ret.use_fast_run = true; | |||
| continue; | |||
| } | |||
| if (!strcmp(argv[i], "--full-profile")) { | |||
| ret.use_full_profile = true; | |||
| continue; | |||
| } | |||
| if (!strcmp(argv[i], "--fast-profile")) { | |||
| ret.use_fast_profile = true; | |||
| if (!strcmp(argv[i], "--full-run")) { | |||
| ret.use_full_run = true; | |||
| continue; | |||
| } | |||
| #endif | |||
| @@ -12,7 +12,6 @@ | |||
| #pragma once | |||
| #include "megbrain_build_config.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| #include "megdnn/basic_types.h" | |||
| #include <memory> | |||
| @@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { | |||
| } // namespace mgb | |||
| namespace megdnn { | |||
| namespace param { | |||
| MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||
| } | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -18,6 +18,7 @@ | |||
| #include "megbrain/utils/hashable.h" | |||
| #include "megbrain/utils/thin/hash_table.h" | |||
| #include "megbrain/utils/small_vector.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| #include <type_traits> | |||
| @@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \ | |||
| } // namespace cg | |||
| } // namespace mgb | |||
| namespace megdnn { | |||
| namespace param { | |||
| MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) | |||
| } | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||
| return ret; | |||
| } | |||
| //! Test whether the algo attribute of a algo match the require | |||
| //! algo_strategy | |||
| static bool algo_attribute_match_strategy(AlgoAttribute attribute, | |||
| ExecutionStrategy selected_strategy) { | |||
| bool ret = true; | |||
| if (selected_strategy & ExecutionStrategy::OPTMIZED) { | |||
| ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute)); | |||
| } else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) { | |||
| ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute); | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace | |||
| namespace mgb { | |||
| @@ -285,8 +298,8 @@ namespace opr { | |||
| template <typename Opr> | |||
| void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||
| ExecutionStrategy select_strategy) { | |||
| if (ctx.get_profile_result_from_cache(select_strategy).valid()) | |||
| ExecutionStrategy selected_strategy) { | |||
| if (ctx.get_profile_result_from_cache(selected_strategy).valid()) | |||
| return; | |||
| AlgoChooserProfileCache::Result prof_rst; | |||
| @@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||
| algo.name.c_str(), str_on_inp_shape.c_str()); | |||
| ImplExecutionPolicy policy; | |||
| policy.algo = algo.desc; | |||
| ctx.construct_execution_policy(select_strategy, policy); | |||
| if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) | |||
| ctx.construct_execution_policy(selected_strategy, policy); | |||
| if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { | |||
| continue; | |||
| } | |||
| auto algo_attribute = ctx.megdnn_opr() | |||
| ->get_algorithm_from_desc(policy.algo) | |||
| ->attribute(); | |||
| if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) { | |||
| mgb_log_debug( | |||
| "skip algo %s, which is not match the profile strategy.", | |||
| algo.name.c_str()); | |||
| continue; | |||
| } | |||
| timer.reset(); | |||
| MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } | |||
| @@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||
| template <typename Opr> | |||
| typename AlgoChooser<Opr>::ImplExecutionPolicy | |||
| AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||
| ExecutionStrategy select_strategy, | |||
| ExecutionStrategy selected_strategy, | |||
| bool enable_update) { | |||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) | |||
| if (ctx.owner_graph()->options().no_profiling_on_shape_change) { | |||
| @@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||
| to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), | |||
| _item.param, ctx.mgb_opr(), ctx.comp_node(), | |||
| ctx.execution_policy(), ctx.allow_weight_preprocess()); | |||
| AlgoChooser<_Opr>::profile(sub_ctx, select_strategy); | |||
| AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); | |||
| }); | |||
| } | |||
| typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | |||
| ctx.construct_execution_policy(select_strategy, policy); | |||
| ctx.construct_execution_policy(selected_strategy, policy); | |||
| return policy; | |||
| MIDOUT_E | |||
| } | |||
| @@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | |||
| if (!policy.algo.valid()) | |||
| policy = ctx.choose_by_heuristic(opr_strategy); | |||
| return policy; | |||
| } else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||
| } else if (!static_cast<int>(opr_strategy) || | |||
| (opr_strategy & ExecutionStrategy::HEURISTIC)) { | |||
| return ctx.choose_by_heuristic(opr_strategy); | |||
| } | |||
| #if MGB_ENABLE_FASTRUN | |||
| @@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( | |||
| } | |||
| #endif | |||
| else { | |||
| mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); | |||
| mgb_throw(GraphError, "bad ExecutionPolicy strategy"); | |||
| } | |||
| } | |||
| @@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext( | |||
| template <typename Opr> | |||
| typename AlgoChooser<Opr>::ImplAlgo | |||
| AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
| ExecutionStrategy select_strategy) const { | |||
| ExecutionStrategy selected_strategy) const { | |||
| MIDOUT_B(Opr, | |||
| midout_iv(MGB_HASH_STR( | |||
| "AlgoChooser::ExeContext::get_profile_result_from_cache"))) | |||
| @@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
| if (prof.empty()) | |||
| return {}; | |||
| for (auto&& i : prof) { | |||
| if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||
| if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) || | |||
| static_cast<AlgoAttribute>(i.attribute) & | |||
| AlgoAttribute::REPRODUCIBLE) { | |||
| auto iter = algo_map.find(i.algo); | |||
| @@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
| template <typename Opr> | |||
| typename AlgoChooser<Opr>::ImplExecutionPolicy | |||
| AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||
| ExecutionStrategy select_strategy) const { | |||
| ExecutionStrategy selected_strategy) const { | |||
| if (m_execution_policy.workspace_limit != | |||
| std::numeric_limits<decltype( | |||
| m_execution_policy.workspace_limit)>::max()) { | |||
| @@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||
| "workspace_limit should not be setted if choose algo by " | |||
| "heuristic"); | |||
| } | |||
| bool reproducible = static_cast<bool>(select_strategy & | |||
| bool reproducible = static_cast<bool>(selected_strategy & | |||
| ExecutionStrategy::REPRODUCIBLE); | |||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
| @@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||
| _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | |||
| m_allow_weight_preprocess); | |||
| policy.sub_policy.push_back( | |||
| sub_ctx.choose_by_heuristic(select_strategy)); | |||
| sub_ctx.choose_by_heuristic(selected_strategy)); | |||
| }); | |||
| return policy; | |||
| @@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const { | |||
| template <typename Opr> | |||
| void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
| ExecutionStrategy select_strategy, | |||
| ExecutionStrategy selected_strategy, | |||
| typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | |||
| bool retrive_from_cache) const { | |||
| bool reproducible = static_cast<bool>(select_strategy & | |||
| bool reproducible = static_cast<bool>(selected_strategy & | |||
| ExecutionStrategy::REPRODUCIBLE); | |||
| if (!policy.algo.valid()) { | |||
| if (retrive_from_cache) { | |||
| policy.algo = | |||
| get_profile_result_from_cache(select_strategy).desc; | |||
| get_profile_result_from_cache(selected_strategy).desc; | |||
| } else { | |||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
| @@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
| _item.param, m_base_mgb_opr, m_cn, m_execution_policy, | |||
| m_allow_weight_preprocess); | |||
| policy.sub_policy.push_back({}); | |||
| sub_ctx.construct_execution_policy(select_strategy, | |||
| sub_ctx.construct_execution_policy(selected_strategy, | |||
| policy.sub_policy.back(), | |||
| retrive_from_cache); | |||
| }); | |||
| @@ -110,7 +110,7 @@ public: | |||
| const FixedTensorLayouts& layouts() const { return m_layouts; } | |||
| ImplExecutionPolicy choose_by_heuristic( | |||
| ExecutionStrategy select_strategy) const; | |||
| ExecutionStrategy selected_strategy) const; | |||
| //! get all candidate algos, and the one choose_by_heuristic() is | |||
| //! put first | |||
| @@ -134,17 +134,17 @@ public: | |||
| //! get all profile algorithm from cache, return invalid if not exists | |||
| ImplAlgo get_profile_result_from_cache( | |||
| ExecutionStrategy select_strategy) const; | |||
| ExecutionStrategy selected_strategy) const; | |||
| /** | |||
| * \brief construct execution policy from cache or heuristic. | |||
| * | |||
| * \param select_strategy select algo which matched this strategy | |||
| * \param selected_strategy select algo which matched this strategy | |||
| * \param policy execution policy | |||
| * \param retrive_from_cache retrive algo from cache if set True, get | |||
| * from heuristic otherwise. | |||
| */ | |||
| void construct_execution_policy(ExecutionStrategy select_strategy, | |||
| void construct_execution_policy(ExecutionStrategy selected_strategy, | |||
| ImplExecutionPolicy& policy, | |||
| bool retrive_from_cache = true) const; | |||
| @@ -161,10 +161,10 @@ private: | |||
| //! profile and save to cache | |||
| static void profile(ExeContext& ctx, ExecutionStrategy select_strategy); | |||
| static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); | |||
| static ImplExecutionPolicy choose_by_profile( | |||
| ExeContext& ctx, ExecutionStrategy select_strategy, | |||
| ExeContext& ctx, ExecutionStrategy selected_strategy, | |||
| bool enable_update = true); | |||
| public: | |||