Some DNN backend oprs delegate to a proxy opr internally;
for example, the naive CPU softmax implementation calls the elemwise opr.
At the model dump stage we cannot see this DNN runtime logic,
so we record the elemwise mode info at the runtime stage instead.
GitOrigin-RevId: 6528b4c85d
tags/v1.10.0
| @@ -17,6 +17,9 @@ | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_common_elemwise) | |||
| //! This tag is used by tools/gen_header_for_bin_reduce.py; | |||
| //! please do not modify it. | |||
| MIDOUT_DECL(megdnn_common_elemwise_mode) | |||
| #include <mutex> | |||
| #include <vector> | |||
| @@ -154,6 +157,88 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||
| #if !MEGDNN_ELEMWISE_MODE_ENABLE_ALL | |||
| megdnn_assert(ret.arity); | |||
| #endif | |||
| //! Some DNN backend oprs delegate to proxy oprs. For example, the naive CPU softmax | |||
| //! implementation calls the elemwise opr. At the model dump stage we have no | |||
| //! information about this logic, so the elemwise mode would be lost. As a solution, | |||
| //! we record the elemwise mode information by adding a 'midout' case flag at the | |||
| //! run stage. | |||
| #define CB_MODE(mode) \ | |||
| case mode: \ | |||
| MIDOUT_BEGIN(megdnn_common_elemwise_mode, midout_iv(mode)) { return ret; } \ | |||
| MIDOUT_END(); \ | |||
| break; | |||
| switch (mode) { | |||
| CB_MODE(Mode::RELU); | |||
| CB_MODE(Mode::ABS); | |||
| CB_MODE(Mode::ACOS); | |||
| CB_MODE(Mode::ASIN); | |||
| CB_MODE(Mode::CEIL); | |||
| CB_MODE(Mode::COS); | |||
| CB_MODE(Mode::EXP); | |||
| CB_MODE(Mode::EXPM1); | |||
| CB_MODE(Mode::FLOOR); | |||
| CB_MODE(Mode::LOG); | |||
| CB_MODE(Mode::LOG1P); | |||
| CB_MODE(Mode::NEGATE); | |||
| CB_MODE(Mode::SIGMOID); | |||
| CB_MODE(Mode::SIN); | |||
| CB_MODE(Mode::TANH); | |||
| CB_MODE(Mode::ABS_GRAD); | |||
| CB_MODE(Mode::ADD); | |||
| CB_MODE(Mode::FLOOR_DIV); | |||
| CB_MODE(Mode::MAX); | |||
| CB_MODE(Mode::MIN); | |||
| CB_MODE(Mode::MOD); | |||
| CB_MODE(Mode::MUL); | |||
| CB_MODE(Mode::POW); | |||
| CB_MODE(Mode::SIGMOID_GRAD); | |||
| CB_MODE(Mode::SUB); | |||
| CB_MODE(Mode::SWITCH_GT0); | |||
| CB_MODE(Mode::TANH_GRAD); | |||
| CB_MODE(Mode::TRUE_DIV); | |||
| CB_MODE(Mode::LOG_SUM_EXP); | |||
| CB_MODE(Mode::LT); | |||
| CB_MODE(Mode::LEQ); | |||
| CB_MODE(Mode::EQ); | |||
| CB_MODE(Mode::SHL); | |||
| CB_MODE(Mode::SHR); | |||
| CB_MODE(Mode::COND_LEQ_MOV); | |||
| CB_MODE(Mode::FUSE_MUL_ADD3); | |||
| CB_MODE(Mode::FUSE_MUL_ADD4); | |||
| CB_MODE(Mode::FUSE_ADD_RELU); | |||
| CB_MODE(Mode::FUSE_ADD_SIGMOID); | |||
| CB_MODE(Mode::FUSE_ADD_TANH); | |||
| CB_MODE(Mode::FAST_TANH); | |||
| CB_MODE(Mode::FAST_TANH_GRAD); | |||
| CB_MODE(Mode::ROUND); | |||
| CB_MODE(Mode::RMULH); | |||
| CB_MODE(Mode::ATAN2); | |||
| CB_MODE(Mode::ERF); | |||
| CB_MODE(Mode::ERFINV); | |||
| CB_MODE(Mode::ERFC); | |||
| CB_MODE(Mode::ERFCINV); | |||
| CB_MODE(Mode::H_SWISH); | |||
| CB_MODE(Mode::H_SWISH_GRAD); | |||
| CB_MODE(Mode::FUSE_ADD_H_SWISH); | |||
| CB_MODE(Mode::NOT); | |||
| CB_MODE(Mode::AND); | |||
| CB_MODE(Mode::OR); | |||
| CB_MODE(Mode::XOR); | |||
| CB_MODE(Mode::SILU); | |||
| CB_MODE(Mode::SILU_GRAD); | |||
| CB_MODE(Mode::GELU); | |||
| CB_MODE(Mode::GELU_GRAD); | |||
| default: | |||
| megdnn_assert( | |||
| 0, | |||
| "code issue happened!!, please add new elemwise to switch mode."); | |||
| return ret; | |||
| #undef CB_MODE | |||
| } | |||
| return ret; | |||
| } | |||
| @@ -77,18 +77,40 @@ class HeaderGen: | |||
| self._dtypes.add(i) | |||
| for i in data["opr_types"]: | |||
| self._oprs.add(i) | |||
| for i in data["elemwise_modes"]: | |||
| self._elemwise_modes.add(i) | |||
def extend_midout(self, fname):
    """Append *fname* to the midout file names held in ``self._midout_files``."""
    midout_files = self._midout_files
    midout_files.append(fname)
def extend_elemwise_mode_info(self, fname):
    """Collect elemwise mode ids recorded in the trace file *fname*.

    Every line containing the ``megdnn_common_elemwise_mode`` tag is passed
    to ``c++filt -t``; the first purely numeric token of the demangled
    output (after dropping ``>`` characters) is added, as a string, to
    ``self._elemwise_modes``. Lines without the tag are ignored.

    Raises:
        AssertionError: a tagged line yields no numeric token.
        subprocess.CalledProcessError: ``c++filt`` exits with an error.
    """
    # tag written in dnn/src/common/elemwise/opr_impl.cpp
    tag = "megdnn_common_elemwise_mode"
    # ``with`` guarantees the trace file is closed even if demangling fails.
    with open(fname) as fin:
        for line in fin:
            if tag not in line:
                continue
            # Pass the symbols as an argument list instead of a shell
            # string so that file content is never interpreted by a shell.
            # Splitting on whitespace mirrors the word splitting the shell
            # used to perform.
            cmd = ["c++filt", "-t"] + line.split()
            demangled = subprocess.check_output(cmd).decode("utf-8")
            tokens = demangled.replace(">", "").split()
            mode = next((tok for tok in tokens if tok.isnumeric()), None)
            if mode is None:
                # Raised explicitly (rather than via ``assert``) so that the
                # check is still performed under ``python -O``.
                raise AssertionError(
                    "code issue happened!! can not find elemwise mode in: {}".format(
                        line
                    )
                )
            self._elemwise_modes.add(mode)
| def generate(self, fout): | |||
| self._fout = fout | |||
| self._write_def("MGB_BINREDUCE_VERSION", "20190219") | |||
| self._write_def("MGB_BINREDUCE_VERSION", "20220507") | |||
| if self._has_netinfo: | |||
| self._write_dtype() | |||
| if len(self._elemwise_modes) > 0: | |||
| self._write_elemwise_modes() | |||
| if self._has_netinfo: | |||
| self._write_oprs() | |||
| self._write_hash() | |||
| self._write_midout() | |||
| @@ -156,22 +178,32 @@ class HeaderGen: | |||
| with open(fpath) as fin: | |||
| mode_list = [i.strip() for i in fin] | |||
| all_elemwise_modes = set() | |||
| for i in mode_list: | |||
| i = i.split(" ")[0].split("=")[0] | |||
| if i in self._elemwise_modes: | |||
| content = "_cb({})".format(i) | |||
| i_type = i.replace(" ", "").replace("=", " ").split()[0] | |||
| i_id = i.replace(" ", "").replace("=", " ").split()[1] | |||
| all_elemwise_modes.add(i_id) | |||
| if i_id in self._elemwise_modes: | |||
| content = "_cb({})".format(i_type) | |||
| else: | |||
| content = "" | |||
| self._write_def( | |||
| "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format( | |||
| i.split(" ")[0].split("=")[0] | |||
| ), | |||
| content, | |||
| "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content, | |||
| ) | |||
| # write end of elemwise macro | |||
| self._write_def( | |||
| "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)", | |||
| "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)", | |||
| ) | |||
| # finally check all self._elemwise_modes is in all_elemwise_modes | |||
| for i in self._elemwise_modes: | |||
| assert ( | |||
| i in all_elemwise_modes | |||
| ), "code issue happened, can not find elemwise mode: {} in {}".format( | |||
| i, all_elemwise_modes | |||
| ) | |||
| def _write_dtype(self): | |||
| if "Float16" not in self._dtypes: | |||
| @@ -267,6 +299,7 @@ def main(): | |||
| with open(i) as fin: | |||
| if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: | |||
| gen.extend_midout(i) | |||
| gen.extend_elemwise_mode_info(i) | |||
| else: | |||
| fin.seek(0) | |||
| gen.extend_netinfo(json.loads(fin.read())) | |||