GitOrigin-RevId: 05577a8bc8

@@ -37,7 +37,6 @@ _ElwMod = builtin.Elemwise.Mode
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
-    args = convert_inputs(*args)
     (result,) = apply(op, *args)
     return result

@@ -249,22 +248,22 @@ class ArrayMethodMixin(abc.ABC):
     __hash__ = None  # due to __eq__ diviates from python convention
     __lt__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="lt", dtype="Bool"
+        self, value, mode="lt", dtype="bool"
     )
     __le__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="leq", dtype="Bool"
+        self, value, mode="leq", dtype="bool"
     )
     __gt__ = lambda self, value: _elemwise_multi_type(
-        value, self, mode="lt", dtype="Bool"
+        value, self, mode="lt", dtype="bool"
     )
     __ge__ = lambda self, value: _elemwise_multi_type(
-        value, self, mode="leq", dtype="Bool"
+        value, self, mode="leq", dtype="bool"
    )
     __eq__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="eq", dtype="Bool"
+        self, value, mode="eq", dtype="bool"
     )
     __ne__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="neq", dtype="Bool"
+        self, value, mode="neq", dtype="bool"
     )
     __neg__ = _unary_elwise(_ElwMod.NEGATE)
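
The intent of the lowercase dtype string can be illustrated with a minimal sketch (standard `megengine` import names assumed; outputs are expected values, not taken from this patch):

    from megengine import Tensor

    x = Tensor([1, 2, 3], dtype="int32")
    mask = x < 2          # dispatches to __lt__ -> ElemwiseMultiType(mode="lt", dtype="bool")
    print(mask.dtype)     # expected: bool
    print(mask.numpy())   # expected: [ True False False]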

@@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor:
         >>> F.isnan(x).numpy()
         array([False, True, False])
     """
-    return _elemwise_multi_type(inp, mode="isnan", dtype="Bool")
+    return _elemwise_multi_type(inp, mode="isnan", dtype="bool")
 
 
 def isinf(inp: Tensor) -> Tensor:

@@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor:
         >>> F.isinf(x).numpy()
         array([False, True, False])
     """
-    return _elemwise_multi_type(inp, mode="isinf", dtype="Bool")
+    return _elemwise_multi_type(inp, mode="isinf", dtype="bool")
 
 
 def sign(inp: Tensor):
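
The docstrings above already show the boolean output; a short sketch of the corresponding dtype check (import names assumed):

    import megengine.functional as F
    from megengine import Tensor

    x = Tensor([1.0, float("inf"), 0.0])
    print(F.isinf(x).numpy())   # array([False,  True, False])
    print(F.isinf(x).dtype)     # expected: bool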

@@ -118,7 +118,7 @@ PyObject* py_apply(
             tensors[i] = tw->m_tensor->data();
         } else if (
                 DTypePromoteCfg::convert_input_enabled &&
-                op->same_type<Elemwise>()) {
+                (op->same_type<Elemwise>() || op->same_type<ElemwiseMultiType>())) {
             tensors[i] = convert_pyinput_to_tensor(i);
         } else {
             PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
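
A hedged sketch of what this extra branch allows: with ElemwiseMultiType included, a plain Python scalar used in a comparison is converted to a tensor by py_apply instead of being rejected with a TypeError (import names assumed):

    from megengine import Tensor

    x = Tensor([0.5, 1.5])
    print((x > 1).numpy())   # the scalar 1 is converted to a tensor; expected: [False  True]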

@@ -53,6 +53,41 @@ mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) {
     return ret;
 }
 
+ValueRefList elemwise_multi_type_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& elem_op = op.cast_final_safe<ElemwiseMultiType>();
+
+    static std::unordered_set<ElemwiseMultiType::Mode> cast_case = {
+            ElemwiseMultiType::Mode::EQ,
+            ElemwiseMultiType::Mode::NEQ,
+            ElemwiseMultiType::Mode::LT,
+            ElemwiseMultiType::Mode::LEQ,
+    };
+
+    if (cast_case.find(elem_op.mode) == cast_case.end()) {
+        return imperative::apply(op, inputs);
+    }
+
+    SmallVector<DType> dtypes(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dtypes[i] = *(inputs[i].dtype());
+    }
+
+    ValueRefList converted(inputs.size());
+    mgb::DType target_dtype = get_promoted_dtype(dtypes);
+
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype &&
+            DTypePromoteCfg::convert_input_enabled) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            dtypes[i] = target_dtype;
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
 ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) {
     auto&& elem_op = op.cast_final_safe<Elemwise>();
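
A small sketch of the promotion this rule targets for the EQ/NEQ/LT/LEQ modes: non-quantized operands whose dtypes differ are cast to the common promoted dtype before the comparison kernel runs (import names assumed; the exact target dtype is whatever get_promoted_dtype returns):

    from megengine import Tensor

    a = Tensor([1, 2, 3], dtype="int32")
    b = Tensor([1.5, 2.0, 2.5], dtype="float32")
    m = a <= b        # both inputs promoted (e.g. to float32) before mode="leq" runs
    print(m.dtype)    # expected: bool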

@@ -349,6 +384,7 @@ ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
 struct DTypePromoteRuleRegistry {
     DTypePromoteRuleRegistry() {
         register_dtype_promote_rule<Elemwise>(elemwise_rule);
+        register_dtype_promote_rule<ElemwiseMultiType>(elemwise_multi_type_rule);
         register_dtype_promote_rule<Concat>(naive_promote_rule);
         register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
         register_dtype_promote_rule<Reduce>(reduce_rule);

@@ -16,52 +16,6 @@
 using namespace mgb;
 using namespace opr;
 
-namespace {
-//! global operator instance for static inference
-template <class Opr>
-class StaticInferOpr {
-    intl::UniqPtrWithCN<Opr> m_opr;
-    MGB_MUTEX m_mtx;
-
-public:
-    class Lock {
-        friend class StaticInferOpr;
-        StaticInferOpr* m_owner;
-
-        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
-#if !__DEPLOY_ON_XP_SP2__
-            m_owner->m_mtx.lock();
-#endif
-        }
-
-    public:
-        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }
-
-        ~Lock() {
-#if !__DEPLOY_ON_XP_SP2__
-            if (m_owner)
-                m_owner->m_mtx.unlock();
-#endif
-        }
-
-        Lock& operator=(const Lock&) = delete;
-        Lock& operator=(Lock&&) = delete;
-
-        intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; }
-    };
-
-    //! lock and acquire the operator
-    Lock lock() {
-        Lock ret{this};
-        if (!m_opr) {
-            m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu());
-        }
-        return ret;
-    }
-};
-}  // anonymous namespace
-
 /* ========================= BatchedDTypePromotion ========================= */
 intl::BatchedDTypePromotion::BatchedDTypePromotion(const VarNodeArrayView& vars)
         : m_orig_vars{vars} {

@@ -1,6 +1,6 @@
 #include "megbrain/opr/nn_int.h"
 #include "./internal/megdnn_opr_wrapper.inl"
+#include "megbrain/opr/utility.h"
 #include "megdnn/oprs/general.h"
 
 using namespace mgb;

@@ -18,6 +18,7 @@ ElemwiseMultiType::ElemwiseMultiType(
     for (auto i : inputs) {
         add_input({i});
     }
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }
 
 SymbolVar ElemwiseMultiType::make(

@@ -52,8 +53,13 @@ void ElemwiseMultiType::init_output_dtype() {
 
 void ElemwiseMultiType::scn_do_execute() {
     megdnn::TensorNDArray inp_arr(input().size());
     for (size_t i = 0; i < input().size(); ++i) {
+        if (input()[i]->dev_tensor().empty()) {
+            mgb_assert(output(0)->dev_tensor().empty());
+            return;
+        }
         inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
     }
+    mgb_assert(!output(0)->dev_tensor().empty());
     megdnn_opr()->exec(inp_arr, output(0)->dev_tensor().as_megdnn());
 }
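
A sketch of the case that ALLOW_EMPTY_SHAPE and the empty-input early return are meant to cover: comparisons on zero-sized tensors should yield an empty boolean result instead of failing (imports assumed; behaviour stated as expected, not verified here):

    import numpy as np
    from megengine import Tensor

    e = Tensor(np.zeros((0, 4), dtype="float32"))
    r = e == e
    print(r.shape)   # expected: (0, 4)
    print(r.dtype)   # expected: bool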

@@ -75,4 +81,120 @@ void ElemwiseMultiType::add_input_layout_constraint() {
 #endif
 }
 
+ElemwiseMultiType::NodeProp* ElemwiseMultiType::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    for (auto& inp : input()) {
+        ret->add_dep_type_existing_var(inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
+    return ret;
+}
+
+void ElemwiseMultiType::init_output_static_infer_desc() {
+    Super::init_output_static_infer_desc();
+    static StaticInferOpr<megdnn::ElemwiseMultiType> static_infer_opr;
+
+    using namespace cg::static_infer;
+
+    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
+        SmallVector<DeviceTensorND> inp_vals(inp.val.size());
+        for (size_t i = 0; i < inp_vals.size(); ++i)
+            inp_vals[i] = inp.val[i].value();
+        DType out_dt;
+        auto trait = ModeTrait::from_mode(param().mode);
+        if (trait.need_specify_out_dtype) {
+            auto dtype = config().output_dtype();
+            mgb_assert(dtype.valid());
+            out_dt = dtype;
+        } else {
+            DType dtype;
+            trait.check_out(dtype, false);
+            out_dt = dtype;
+        }
+        auto sopr = static_infer_opr.lock();
+        perform(param().mode, out_dt, dest, inp_vals, sopr());
+        return true;
+    };
+
+    DepVal deps(input().size());
+    for (size_t i = 0; i < input().size(); ++i)
+        deps[i] = {input(i), DepType::VALUE};
+    owner_graph()->static_infer_manager().register_value_infer(
+            output(0), {SourceType::DEP, deps, infer_value});
+}
+
+TensorShape ElemwiseMultiType::get_output_var_shape(
+        Mode mode, const TensorShapeArray& input_shapes) {
+    mgb_assert(input_shapes.size() == ModeTrait::from_mode(mode).arity);
+    TensorShape ret;
+    megdnn::Elemwise::deduce_shape(input_shapes, ret);
+    return ret;
+}
+
+void ElemwiseMultiType::call_megdnn_opr_exec(
+        CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
+        megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller) {
+    // All Elemwise operations on QuantizedS32/QuantizedS8 are unrelated to the
+    // scale. Since MegDNN does not support computing Elemwise on
+    // QuantizedS32/QuantizedS8, we reinterpret the data type as Int32/Int8
+    // before passing the tensors to MegDNN.
+    if (inp.size() && inp[0].layout.dtype.category() == DTypeCategory::QUANTIZED) {
+        auto inp_dtype = inp[0].layout.dtype;
+        DType compute_dtype;
+        if (inp_dtype.enumv() == DTypeEnum::QuantizedS32) {
+            compute_dtype = dtype::Int32();
+        } else if (inp_dtype.enumv() == DTypeEnum::QuantizedS8) {
+            compute_dtype = dtype::Int8();
+        } else {
+            mgb_throw(
+                    MegBrainError, "Unsupported Quantized Elemwise Mode %s: %d on %s",
+                    inp[0].layout.dtype.name(), int(opr->param().mode),
+                    comp_node.to_string().c_str());
+        }
+
+        megdnn::TensorNDArray run_inp(inp);
+        for (size_t i = 0; i < inp.size(); i++) {
+            run_inp[i].layout.dtype = compute_dtype;
+        }
+        megdnn::TensorND run_out = out;
+        run_out.layout.dtype = compute_dtype;
+        opr->exec(run_inp, run_out);
+        return;
+    }
+
+    opr->exec(inp, out);
+}
+
+void ElemwiseMultiType::perform(
+        Mode mode, DType out_dt, DeviceTensorND& dest,
+        const SmallVector<DeviceTensorND>& inputs,
+        intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr) {
+    megdnn::TensorNDArray dnn_inputs(inputs.size());
+    TensorShapeArray inp_shapes(inputs.size());
+    CompNode out_cn;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& t = inputs[i];
+        if (!i) {
+            out_cn = t.comp_node();
+        } else {
+            mgb_assert(t.comp_node() == out_cn);
+        }
+        if (t.shape().is_empty()) {
+            mgb_assert(dest.empty());
+            return;
+        }
+        inp_shapes[i] = t.shape();
+    }
+    if (!opr) {
+        opr = intl::create_megdnn_opr<megdnn::ElemwiseMultiType>(out_cn);
+    } else {
+        mgb_assert(out_cn == opr.comp_node());
+    }
+    out_cn.activate();
+    for (size_t i = 0; i < inputs.size(); ++i)
+        dnn_inputs[i] = inputs[i].as_megdnn();
+    dest.comp_node(out_cn).dtype(out_dt).resize(get_output_var_shape(mode, inp_shapes));
+    opr->param() = {mode};
+    call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
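
Since get_output_var_shape defers to megdnn::Elemwise::deduce_shape, the user-visible shape rule is ordinary elementwise broadcasting; a brief sketch (imports assumed):

    from megengine import Tensor

    a = Tensor([[1.0], [2.0]])    # shape (2, 1)
    b = Tensor([0.5, 1.5, 2.5])   # shape (3,)
    print((a < b).shape)          # broadcast result; expected: (2, 3)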

@@ -26,6 +26,14 @@ public:
             const VarNodeArrayView& inputs, Param param,
             const OperatorNodeConfig& config = {});
 
+    MGE_WIN_DECLSPEC_FUC static TensorShape get_output_var_shape(
+            Mode mode, const TensorShapeArray& input_shapes);
+
+    MGE_WIN_DECLSPEC_FUC static void perform(
+            Mode mode, DType out_dt, DeviceTensorND& dest,
+            const SmallVector<DeviceTensorND>& inputs,
+            intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr);
+
 private:
     using ModeTrait = megdnn::ElemwiseMultiType::ModeTrait;

@@ -40,6 +48,14 @@ private:
     void record_execute_deps(ExecDependencyArray& deps) override;
     void add_input_layout_constraint() override;
+    NodeProp* do_make_node_prop() const override;
+    void init_output_static_infer_desc() override;
+
+    static void call_megdnn_opr_exec(
+            CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
+            megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller);
 };
 
 //! deprecated; TODO: remove in megbrain 8

@@ -509,6 +509,49 @@ public:
     bool is_const() const { return m_is_const; }
 };
 
+//! global operator instance for static inference
+template <class Opr>
+class StaticInferOpr {
+    intl::UniqPtrWithCN<Opr> m_opr;
+    MGB_MUTEX m_mtx;
+
+public:
+    class Lock {
+        friend class StaticInferOpr;
+        StaticInferOpr* m_owner;
+
+        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
+#if !__DEPLOY_ON_XP_SP2__
+            m_owner->m_mtx.lock();
+#endif
+        }
+
+    public:
+        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }
+
+        ~Lock() {
+#if !__DEPLOY_ON_XP_SP2__
+            if (m_owner)
+                m_owner->m_mtx.unlock();
+#endif
+        }
+
+        Lock& operator=(const Lock&) = delete;
+        Lock& operator=(Lock&&) = delete;
+
+        intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; }
+    };
+
+    //! lock and acquire the operator
+    Lock lock() {
+        Lock ret{this};
+        if (!m_opr) {
+            m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu());
+        }
+        return ret;
+    }
+};
+
 }  // namespace opr
 }  // namespace mgb