| @@ -315,5 +315,9 @@ r""" | |||||
| """), | """), | ||||
| has_out_dtype=True) | has_out_dtype=True) | ||||
| decl_opr('FakeQuant', | |||||
| inputs=[Doc('src','input tenosr'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor')], | |||||
| params='FakeQuant') | |||||
| # vim: ft=python | # vim: ft=python | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
| #include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
| #include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
| #include "megbrain/opr/dnn/fake_quant.h" | |||||
| #include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
| @@ -423,6 +424,8 @@ namespace opr { | |||||
| MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | ||||
| MGB_SEREG_OPR(BatchConvBiasForward, 0); | MGB_SEREG_OPR(BatchConvBiasForward, 0); | ||||
| MGB_SEREG_OPR(FakeQuant, 3); | |||||
| MGB_SEREG_OPR(FakeQuantBackward, 4); | |||||
| } // namespace opr | } // namespace opr | ||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * \file src/opr/impl/dnn/fake_quant.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megbrain/opr/dnn/fake_quant.h" | |||||
| #include "../internal/megdnn_opr_wrapper.inl" | |||||
| #include "megbrain/graph/grad_impl.h" | |||||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
| #include "megbrain/opr/utility.h" | |||||
| using namespace mgb; | |||||
| using namespace opr; | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantForward); | |||||
| MEGDNN_OPR_INIT3(FakeQuantForward, "fakequant_fwd"); | |||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(FakeQuantForward) { | |||||
| if (wrt_idx == 0) { | |||||
| // wrt src | |||||
| SymbolVar grad = | |||||
| FakeQuantBackward::make(out_grad[0], opr.input(0), opr.input(1), | |||||
| opr.input(2), opr.param()); | |||||
| return grad.node(); | |||||
| } else { | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantBackward); | |||||
| MEGDNN_OPR_INIT4(FakeQuantBackward, "fakequant_bwd", 1, true); | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * \file src/opr/include/megbrain/opr/dnn/fake_quant.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
| #include "megdnn/oprs.h" | |||||
| namespace mgb { | |||||
| namespace opr { | |||||
| MGB_DEFINE_OPR_CLASS(FakeQuantForward, | |||||
| intl::MegDNNOprWrapperFwd<megdnn::FakeQuantForward>) // { | |||||
| public: | |||||
| FakeQuantForward(VarNode* src, VarNode* scale, VarNode* zero_point, | |||||
| const Param& param, const OperatorNodeConfig& config); | |||||
| static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point, | |||||
| const Param& param = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| }; // namespace opr | |||||
| using FakeQuant = FakeQuantForward; | |||||
| MGB_DEFINE_OPR_CLASS(FakeQuantBackward, | |||||
| intl::MegDNNOprWrapperBwd<megdnn::FakeQuantBackward>) // { | |||||
| public: | |||||
| FakeQuantBackward(VarNode* diff, VarNode* input, VarNode* scale, | |||||
| VarNode* zero_point, const Param& param, | |||||
| const OperatorNodeConfig& config); | |||||
| static SymbolVar make(SymbolVar diff, SymbolVar input, SymbolVar scale, | |||||
| SymbolVar zero_point, const Param& param = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| }; | |||||
| } // namespace mgb | |||||
| } // namespace opr | |||||
| @@ -102,6 +102,7 @@ union OperatorParam { | |||||
| param.AdaptivePooling = 70, | param.AdaptivePooling = 70, | ||||
| param.NvOf = 71, | param.NvOf = 71, | ||||
| param.DctChannelSelect = 72, | param.DctChannelSelect = 72, | ||||
| param.FakeQuant = 73, | |||||
| } | } | ||||
| table Operator { | table Operator { | ||||