@@ -1122,7 +1122,8 @@ def batch_norm(
     momentum: float = 0.9,
     eps: float = 1e-5,
     inplace: bool = True,
-    compute_mode="default"
+    compute_mode="default",
+    param_dim="dim_1c11"
 ):
     r"""Applies batch normalization to the input.
@@ -1147,16 +1148,23 @@ def batch_norm(
     if inp.ndim != 4:
         raise NotImplementedError("batch_norm for ndim != 4")

-    C = inp.shape[1]
+    if param_dim == "dim_1c11":
+        C = inp.shape[1]
+        pshape = (1, C, 1, 1)
+    elif param_dim == "dim_111c":
+        C = inp.shape[3]
+        pshape = (1, 1, 1, C)
+    else:
+        raise ValueError("Invalid param_dim {}".format(param_dim))

     def make_full_if_none(x, value):
         if x is None:
             (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
-            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
+            shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
         elif x.ndim == 1:
-            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
+            shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Reshape(), x, shape)
             return result
         return x
@@ -1183,19 +1191,19 @@ def batch_norm(
     if not training:
         op = builtin.BatchNorm(
-            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim=param_dim
         )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret
     else:
         op = builtin.BatchNorm(
-            avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
+            avg_factor=1 - momentum, epsilon=eps, param_dim=param_dim
         )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
             running_var = make_full_if_none(running_var, 1)
-            new_mean, new_var, _, _, inp = apply(
+            new_mean, new_var, *_, inp = apply(
                 op, inp, weight, bias, running_mean, running_var
             )
             if not has_mean:
@@ -1213,7 +1221,7 @@ def batch_norm(
             else:
                 return inp, new_mean, new_var
         else:
-            (_, _, inp,) = apply(op, inp, weight, bias)
+            inp = apply(op, inp, weight, bias)[-1]
             return inp
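Note: a minimal usage sketch of the new param_dim argument (not part of the patch; shapes and values are illustrative, mirroring test_batchnorm2d_autocast further down):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # NHWC input with C == 8; parameters are laid out as (1, 1, 1, C)
    inp = tensor(np.random.randn(2, 16, 16, 8), dtype=np.float32)
    weight = tensor(np.ones((1, 1, 1, 8), dtype=np.float32))
    bias = tensor(np.zeros((1, 1, 1, 8), dtype=np.float32))
    out = F.batch_norm(
        inp, weight=weight, bias=bias, training=True, inplace=False, param_dim="dim_111c"
    )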
@@ -27,6 +27,7 @@ class _BatchNorm(Module):
         track_running_stats=True,
         freeze=False,
         compute_mode="default",
+        param_dim="dim_1c11",
         **kwargs
     ):
         super(_BatchNorm, self).__init__(**kwargs)
@@ -38,6 +39,7 @@ class _BatchNorm(Module):
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
         self.compute_mode = compute_mode
+        self.param_dim = param_dim
         if self.freeze:
             assert (
                 self._track_running_stats_saved
@@ -125,6 +127,7 @@ class _BatchNorm(Module):
                 momentum=exponential_average_factor,
                 eps=self.eps,
                 compute_mode=self.compute_mode,
+                param_dim=self.param_dim,
             )

         if _ndims != 4:
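The module path simply stores the new argument and threads it through to F.batch_norm. A hedged sketch of ordinary use with the default layout, so behavior is unchanged (names and shapes illustrative):

    import numpy as np
    import megengine as mge
    import megengine.module as M

    bn = M.BatchNorm2d(8)   # param_dim defaults to "dim_1c11", i.e. (1, C, 1, 1) parameters
    x = mge.tensor(np.random.randn(2, 8, 16, 16), dtype=np.float32)   # NCHW input
    y = bn(x)               # forward now passes param_dim=self.param_dim to F.batch_norm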
@@ -811,7 +811,8 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


-def test_conv2d_io16c32():
+def test_conv2d_autocast():
+    """Check that amp's result equals the manually converted result."""
     amp.enabled = True
     inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
     weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
@@ -918,11 +919,14 @@ def test_layer_norm():
     assert abs(outvar.mean()) < 1e-7


-def test_batchnorm2d_io16c32():
+def test_batchnorm2d_autocast():
+    """Check that amp's result equals the manually converted result."""
     amp.enabled = True
-    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
-    weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
-    bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
+    tshape = (1, 224, 224, 3)
+    pshape = (1, 1, 1, 3)
+    inp = tensor(np.random.randn(*tshape), dtype=np.float32)
+    weight = tensor(np.ones(pshape, dtype=np.float32))
+    bias = tensor(np.zeros(pshape, dtype=np.float32))

     out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
@@ -51,16 +51,16 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
     // need running mean/variance
     bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING;
-    size_t nr_out = need_stat? 5 : 3;
+    size_t nr_out = need_stat? 6 : 4;
     SmallVector<LogicalTensorDesc> out_shapes(nr_out);
     auto&& i0 = inputs[0];
     auto&& i1 = inputs[1];
     // [running_mean, running_var,] save_mean, save_var
-    for (size_t i = 0; i < nr_out-1; ++ i) {
+    for (size_t i = 0; i < nr_out-2; ++ i) {
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
-    // output tensor
-    out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
+    out_shapes[nr_out-2] = {TensorLayout({0}, dtype::Byte()), i0.comp_node};  // reserve
+    out_shapes[nr_out-1] = {i0.layout, i0.comp_node};  // output
     return {out_shapes, out_shapes[nr_out-1].layout.ndim != 0};
 }
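Note (illustration only, not part of the patch): training with running statistics now yields six outputs and the other modes four, with the reserve tensor placed just before y. The ordering assumed by the Python unpacking above, spelled out with plain lists:

    # 6 outputs when training with running stats, 4 otherwise
    outs_train = ["running_mean", "running_var", "save_mean", "save_var", "reserve", "y"]
    outs_plain = ["save_mean", "save_var", "reserve", "y"]
    new_mean, new_var, *_, out = outs_train   # mirrors `new_mean, new_var, *_, inp = apply(...)`
    assert (new_mean, new_var, out) == ("running_mean", "running_var", "y")
    assert outs_plain[-1] == "y"              # mirrors `apply(op, inp, weight, bias)[-1]`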
@@ -689,7 +689,8 @@ ProxyGraph::make_backward_graph(
         output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
     }
     auto output_grads = make_input_place_holders(output_descs);
-    mgb_assert(output_grads.size() == output_has_grad.size());
+    mgb_assert(output_grads.size() == output_has_grad.size(), "%zu vs %zu",
+            output_grads.size(), output_has_grad.size());
     bool any_input_has_grad = false;
     for (size_t i = 0; i < output_grads.size(); ++ i) {
         if (!output_has_grad[i]) {
@@ -207,7 +207,7 @@ TEST(TestImperative, BatchNormGrad) {
         attr.param.write_pod(param);
         OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
                 {true, true, true, false, false},
-                {false, false, false, false, true});
+                {false, false, false, false, false, true});
     }
     {
         auto op = OprAttr::make("BatchNorm");
@@ -216,7 +216,7 @@ TEST(TestImperative, BatchNormGrad) {
         param.fwd_mode = Param::FwdMode::TRAINING;
         attr.param.write_pod(param);
         OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true},
-                {false, false, true});
+                {false, false, false, true});
     }
 }
@@ -99,7 +99,7 @@ UNUSED void print(const char* s) {
 OprChecker::OprChecker(std::shared_ptr<OpDef> opdef)
         : m_op(opdef) {}

-void OprChecker::run(std::vector<InputSpec> inp_keys) {
+void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
     HostTensorGenerator<> gen;
     size_t nr_inps = inp_keys.size();
     SmallVector<HostTensorND> host_inp(nr_inps);
@@ -151,6 +151,8 @@ void OprChecker::run(std::vector<InputSpec> inp_keys) {
     func->execute().wait();  // run last because it may contain inplace operations
     for(size_t i = 0; i < nr_oups; ++ i) {
+        if (bypass.find(i) != bypass.end())
+            continue;
         MGB_ASSERT_TENSOR_EQ(host_sym_oup[i], host_imp_oup[i]);
     }
 }
@@ -23,7 +23,7 @@ class OprChecker {
 public:
     using InputSpec = std::variant<HostTensorND, TensorShape>;
     OprChecker(std::shared_ptr<OpDef> opdef);
-    void run(std::vector<InputSpec> inp_shapes);
+    void run(std::vector<InputSpec> inp_shapes, std::set<size_t> bypass={});
 private:
     std::shared_ptr<OpDef> m_op;
 };
@@ -73,7 +73,7 @@ TEST(TestImperative, BatchNorm) {
             TensorShape{1, C, 1, 1},
             TensorShape{1, C, 1, 1},
             TensorShape{1, C, 1, 1}
-    });
+    }, {4});
 }

 TEST(TestImperative, Concat) {
@@ -1766,7 +1766,7 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
                         x.dtype().name(), res.dtype().name());
             }
             rewriter.replace_var(
-                    opr->output(4), res.node(),
+                    opr->output(5), res.node(),
                     mgb_cstr_log(
                             "replace batch_norm(x, scale, bias, mean, "
                             "varience) "
@@ -35,6 +35,7 @@
 #include "megdnn/tensor_format.h"

 #include <random>
+#include <vector>

 #if MGB_CUDA
 #include <cudnn.h>
@@ -1665,44 +1666,49 @@ TEST(TestGoptInference, concatbypass) {
 TEST(TestGoptInference, ConvertBatchNormPass) {
     auto cn = CompNode::load("cpu0");
-    HostTensorGenerator<> gen(0, 1, 0);
-    auto graph = ComputingGraph::make();
-    graph->options().graph_opt_level = 0;
-    auto mkvar = [&](const char* name, const TensorShape& shp) {
-        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
-    };
-    auto mkcvar = [&](const char* name, const TensorShape& shp) {
-        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
-                .rename(name);
-    };
-    using Param = opr::BatchNorm::Param;
-    Param param(Param::ParamDim::DIM_1C11, Param::FwdMode::INFERENCE);
-    TensorShape shp = {1, 3, 1, 1};
-    auto x = mkvar("x", {2, 3, 16, 24}), scale = mkcvar("scale", shp),
-         bias = mkcvar("bias", shp), mean = mkcvar("mean", shp);
-    auto host_variance = gen(shp, cn);
-    for (size_t i = 0; i < shp.total_nr_elems(); ++i) {
-        host_variance->ptr<float>()[i] =
-                std::abs(host_variance->ptr<float>()[i]);
-    }
-    auto variance = opr::SharedDeviceTensor::make(*graph, *host_variance)
-                            .rename("variance");
-    auto y = opr::BatchNorm::make(x, scale, bias, mean, variance, param)[4];
-    SymbolVar y_opt;
-    unpack_vector(gopt::optimize_for_inference(
-                          {y}, gopt::OptimizeForInferenceOptions{}),
-                  y_opt);
-    ASSERT_EQ(0u, find_opr_num<opr::BatchNorm>(y_opt));
-    graph->compile({{y_opt, {}}})
-            ->to_json()
-            ->writeto_fpath(
-                    output_file("TestGoptInference.ConvertBatchNormPass.json"));
-    HostTensorND host_y, host_y_opt;
-    auto func = graph->compile({make_callback_copy(y, host_y),
-                                make_callback_copy(y_opt, host_y_opt)});
-    func->execute();
-    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+    std::vector<TensorShape> shps = {{1, 3, 1, 1}, {1, 1, 1, 3}},
+                             xshps = {{2, 3, 16, 24}, {2, 16, 24, 3}};
+    for (int t = 0; t < 2; t++) {
+        HostTensorGenerator<> gen(0, 1, 0);
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        auto mkvar = [&](const char* name, const TensorShape& shp) {
+            return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+        };
+        auto mkcvar = [&](const char* name, const TensorShape& shp) {
+            return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                    .rename(name);
+        };
+        using Param = opr::BatchNorm::Param;
+        Param::ParamDim param_dim = t == 0 ? Param::ParamDim::DIM_1C11 : Param::ParamDim::DIM_111C;
+        Param param(param_dim, Param::FwdMode::INFERENCE);
+        TensorShape shp = shps[t], xshp = xshps[t];
+        auto x = mkvar("x", xshp), scale = mkcvar("scale", shp),
+             bias = mkcvar("bias", shp), mean = mkcvar("mean", shp);
+        auto host_variance = gen(shp, cn);
+        for (size_t i = 0; i < shp.total_nr_elems(); ++i) {
+            host_variance->ptr<float>()[i] =
+                    std::abs(host_variance->ptr<float>()[i]);
+        }
+        auto variance = opr::SharedDeviceTensor::make(*graph, *host_variance)
+                                .rename("variance");
+        auto y = opr::BatchNorm::make(x, scale, bias, mean, variance, param)[5];
+        SymbolVar y_opt;
+        unpack_vector(gopt::optimize_for_inference(
+                              {y}, gopt::OptimizeForInferenceOptions{}),
+                      y_opt);
+        ASSERT_EQ(0u, find_opr_num<opr::BatchNorm>(y_opt));
+        graph->compile({{y_opt, {}}})
+                ->to_json()
+                ->writeto_fpath(
+                        output_file("TestGoptInference.ConvertBatchNormPass.json"));
+        HostTensorND host_y, host_y_opt;
+        auto func = graph->compile({make_callback_copy(y, host_y),
+                                    make_callback_copy(y_opt, host_y_opt)});
+        func->execute();
+        MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+    }
 }

 TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {
@@ -62,10 +62,12 @@ BatchNormForward::BatchNormForward(VarNode *x,
     }
     init_megdnn_opr(*this, param);
-    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     add_input({x, scale, bias, mean, variance});
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);  // reserve
+    output(5)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    // running mean/var
     if (param.fwd_mode == Param::FwdMode::INFERENCE) {
         auto mark_empty_var = [&](VarNode *var) {
             var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
@@ -92,9 +94,10 @@ BatchNormForward::BatchNormForward(VarNode *x,
         {x, scale, bias}}
 {
     init_megdnn_opr(*this, param);
-    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     add_input({x, scale, bias});
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);  // reserve
+    output(5)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     auto mark_empty_var = [&](VarNode *var) {
         var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
             .add_flag(VarNode::Flag::VOLATILE_CONTENT);
@@ -151,7 +154,7 @@ BatchNormForward::do_make_node_prop() const {

 void BatchNormForward::scn_do_execute() {
     auto &&x = input(0)->dev_tensor();
-    auto &&y = output(4)->dev_tensor();
+    auto &&y = output(5)->dev_tensor();
     if (need_stats()) {
         auto &&o0 = output(0)->dev_tensor(),
              &&o1 = output(1)->dev_tensor(),
@@ -192,9 +195,10 @@ void BatchNormForward::scn_do_execute() {
     }
     auto save_mean = output(2)->dev_tensor().as_megdnn();
     auto save_variance = output(3)->dev_tensor().as_megdnn();
+    auto reserve = output(4)->dev_tensor().as_megdnn();
     auto workspace = intl::get_megdnn_workspace_from_var(output().back());
     megdnn_opr()->exec(x.as_megdnn(), scale, bias, mean, variance,
-        save_mean, save_variance, y.as_megdnn(), workspace);
+        save_mean, save_variance, reserve, y.as_megdnn(), workspace);
 }

 void BatchNormForward::add_input_layout_constraint() {
@@ -208,18 +212,25 @@ void BatchNormForward::get_output_var_shape(
             "expect input, scale and bias to be 4 dim tensor, but "
             "got input dim: %zu, scale dim: %zu, bias dim: %zu",
             inp_shape[0].ndim, inp_shape[1].ndim, inp_shape[2].ndim);
-    size_t inp_c = inp_shape[0][1],
-           scale_c = inp_shape[1][1],
-           bias_c = inp_shape[2][1];
+    size_t channel_idx;
+    if (param().param_dim == Param::ParamDim::DIM_111C) {
+        channel_idx = 3;
+    } else {
+        channel_idx = 1;
+    }
+    size_t inp_c = inp_shape[0][channel_idx],
+           scale_c = inp_shape[1][channel_idx],
+           bias_c = inp_shape[2][channel_idx];
     mgb_assert(inp_c == scale_c && inp_c == bias_c,
         "inconsistent channel size, input chennel: %zu, scale channel: %zu, bias channel: %zu",
         inp_c, scale_c, bias_c);
-    out_shape[4] = inp_shape[0];
+    out_shape[5] = inp_shape[0];
     for (size_t i = 0; i < 4; ++ i) {
         out_shape[i] = inp_shape[1];
     }
+    out_shape[4] = {megdnn_opr()->get_reserve_in_bytes({inp_shape[0], input(0)->dtype()})};
     if (!need_stats()) {
         out_shape[0] = out_shape[1] = {0};
     }
@@ -231,7 +242,7 @@ size_t BatchNormForward::get_workspace_size_bytes(
 #define in(x) {input_shapes[x], input(x)->dtype()}
 #define out(x) {output_shapes[x], output(x)->dtype()}
     return megdnn_opr()->get_workspace_in_bytes(
-            in(0), in(1), in(2), out(0), out(1), out(2), out(3), out(4));
+            in(0), in(1), in(2), out(0), out(1), out(2), out(3), out(4), out(5));
 #undef in
 #undef out
 }
@@ -249,7 +260,8 @@ void BatchNormForward::init_output_dtype() {
     for (size_t i = 2; i < nr_inp; ++ i) {
         mgb_assert(input(1)->dtype() == input(i)->dtype());
     }
-    output(4)->dtype(input(0)->dtype());
+    output(4)->dtype(dtype::Byte());  // reserve
+    output(5)->dtype(input(0)->dtype());  // output
     for (size_t i = 0; i < 4; ++ i) {
         output(i)->dtype(input(1)->dtype());
     }
@@ -271,9 +283,10 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) {
     switch (opr.param().fwd_mode) {
     case BatchNorm::Param::FwdMode::TRAINING:
         grad = BatchNormBackward::make(
-            opr.input(0), out_grad[4],
+            opr.input(0), out_grad[5],
             opr.output(2), opr.output(3),
-            opr.input(1), opr.param());
+            opr.input(1), opr.output(4),  // reserve
+            opr.param());
         for (size_t i = 0; i < 3; ++ i) {
             ret[i] = grad[(i + 2) % 3].node();
         }
@@ -281,13 +294,13 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) {
     case BatchNorm::Param::FwdMode::INFERENCE:
         auto sqrt_var = PowC::make((SymbolVar{opr.input(4)}
             + static_cast<dt_float32>(opr.param().epsilon)), 0.5, opr.config());
-        auto d_bn_scale_unreduced = SymbolVar{out_grad[4]} *
+        auto d_bn_scale_unreduced = SymbolVar{out_grad[5]} *
             (SymbolVar{opr.input(0)} - SymbolVar{opr.input(3)}) / sqrt_var;
         auto d_bn_scale = Reduce::make(d_bn_scale_unreduced,
             Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(1)));
-        auto d_bn_bias = Reduce::make(out_grad[4],
+        auto d_bn_bias = Reduce::make(out_grad[5],
             Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(2)));
-        auto dx = SymbolVar{out_grad[4]} * SymbolVar{opr.input(1)} / sqrt_var;
+        auto dx = SymbolVar{out_grad[5]} * SymbolVar{opr.input(1)} / sqrt_var;

         ret[0] = dx.node();
         ret[1] = d_bn_scale.node();
@@ -302,26 +315,26 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormBackward);

 BatchNormBackward::BatchNormBackward(VarNode *x,
         VarNode *y_grad, VarNode *save_mean,
-        VarNode* save_variance, VarNode *scale,
+        VarNode* save_variance, VarNode *scale, VarNode *reserve,
         const Param &param, const OperatorNodeConfig &config):
     Super({x->owner_graph(), config, "batch_norm_bwd",
-           {x, y_grad, save_mean, save_variance, scale}},
+           {x, y_grad, save_mean, save_variance, scale, reserve}},
           0, true)
 {
     init_megdnn_opr(*this, param);
-    add_input({x, y_grad, save_mean, save_variance, scale});
+    add_input({x, y_grad, save_mean, save_variance, scale, reserve});
 }

 SymbolVarArray BatchNormBackward::make(SymbolVar x,
         SymbolVar y_grad, SymbolVar save_mean,
-        SymbolVar save_variance, SymbolVar scale,
+        SymbolVar save_variance, SymbolVar scale, SymbolVar reserve,
         const Param &param,
         const OperatorNodeConfig &config) {
     auto&& out = x.node()
             ->owner_graph()
             ->insert_opr(std::make_unique<BatchNormBackward>(
                     x.node(), y_grad.node(), save_mean.node(),
-                    save_variance.node(), scale.node(), param, config))
+                    save_variance.node(), scale.node(), reserve.node(), param, config))
             ->output();
     SymbolVarArray ret(out.size());
     for (size_t i = 0; i < ret.size(); i++) {
@@ -355,4 +368,11 @@ void BatchNormBackward::init_output_dtype() {
     output(2)->dtype(input(0)->dtype());
 }

+cg::OperatorNodeBase::NodeProp*
+BatchNormBackward::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(5),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -391,14 +391,14 @@ struct OprMaker<opr::BatchNorm, 0> {
 };

 template <>
-struct OprMaker<opr::BatchNormBackward, 5> {
+struct OprMaker<opr::BatchNormBackward, 6> {
     using Param = opr::BatchNormBackward::Param;
     static cg::OperatorNodeBase* make(const Param& param,
                                       const cg::VarNodeArray& i,
                                       ComputingGraph& graph,
                                       const OperatorNodeConfig& config) {
         MGB_MARK_USED_VAR(graph);
-        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], param,
+        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param,
                                             config)[0]
                 .node()
                 ->owner_opr();
@@ -576,7 +576,7 @@ using ConvBiasForwardV4 = ConvBiasForward;
 MGB_SEREG_OPR(ConvBiasForwardV4, 0);

 MGB_SEREG_OPR(BatchNorm, 0);
-MGB_SEREG_OPR(BatchNormBackward, 5);
+MGB_SEREG_OPR(BatchNormBackward, 6);

 using LocalShareForwardV1 = LocalShareForward;
 using LocalShareBackwardDataV1 = LocalShareBackwardData;
@@ -183,6 +183,10 @@ namespace {
 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
 #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

+#define _NR_INPUTS 6
+#define _NR_OUTPUTS 3
+#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
+#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
 } // anonymous namespace

 /* ======================= MegDNNOprWrapperFwd ======================= */
@@ -24,7 +24,7 @@ namespace opr {
 /* input:
  *   x, scale, bias, [running_mean, running_variance]
  * output:
- *   running_mean, running_variance, save_mean, save_inv_variance, y
+ *   running_mean, running_variance, save_mean, save_inv_variance, reserve, y
  *
  * All params have the same definition with cudnn batch normalization.
  *
@@ -35,6 +35,9 @@ namespace opr {
  *
  * For statistic(mean and variance) update:
  *   running_mean = (1 - moving_average) * running_mean + moving_average * new_mean
+ *
+ * The reserve output is used by cudnnBatchNormalizationForwardTrainingEx and
+ * should be preserved for the backward pass.
  */
 MGB_DEFINE_OPR_CLASS(BatchNormForward,
         cg::OutshapePureByInshapeOpr<
@@ -86,7 +89,7 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
 using BatchNorm = BatchNormForward;

 /* input:
- *   x, y_grad, save_mean, save_inv_variance, scale
+ *   x, y_grad, save_mean, save_inv_variance, scale, reserve
  * output:
  *   scale_grad, bias_grad, x_grad
  */
@@ -97,15 +100,17 @@ MGB_DEFINE_OPR_CLASS(BatchNormBackward,
     public:
         BatchNormBackward(VarNode *x, VarNode *y_grad,
                 VarNode *save_mean, VarNode *save_variance,
-                VarNode *scale,
+                VarNode *scale, VarNode *reserve,
                 const Param &param,
                 const OperatorNodeConfig &config);
         static SymbolVarArray make(SymbolVar x,
                 SymbolVar y_grad, SymbolVar save_mean,
                 SymbolVar save_variance, SymbolVar scale,
+                SymbolVar reserve,
                 const Param &param = {},
                 const OperatorNodeConfig &config = {});
     private:
+        NodeProp* do_make_node_prop() const override;
         void init_output_static_infer_desc() override;
         void init_output_dtype() override;
 };
@@ -95,13 +95,13 @@ SymbolVarArray batch_norm(const SymbolVarArray& inputs, const Param &param) {
     SymbolVarArray ret;
     if (inputs.size() == 3) {
         ret = opr::BatchNorm::make(inputs[0], inputs[1], inputs[2], param);
-        return {ret[4], ret[2], ret[3]};
+        return {ret[5], ret[2], ret[3]};
     }
     else {
         mgb_assert(inputs.size() == 5);
         ret = opr::BatchNorm::make(inputs[0], inputs[1], inputs[2],
             inputs[3], inputs[4], param);
-        return {ret[4], ret[0], ret[1]};
+        return {ret[5], ret[0], ret[1]};
     }
 }