GitOrigin-RevId: 982dee36e1
tags/v0.6.0
| @@ -2049,16 +2049,17 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||
| return new_opr; | |||
| } | |||
| }; | |||
| auto replace_concat_opr = [=](OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| //! When input change and all input can convert to nchwxx, this opr will run | |||
| //! in nchwxx mode, else it will run in nchw mode, for example concat and | |||
| //! elemwise opr | |||
| auto replace_multi_inp_opr = [=](OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| bool has_inp_changed = false; | |||
| bool can_exec_ncwxx = true; | |||
| for (size_t i = 0; i < opr->input().size(); i++) { | |||
| if (new_inp[i]->shape().ndim == 5) { | |||
| has_inp_changed = true; | |||
| break; | |||
| } else if (new_inp[i]->shape().ndim == 4) { | |||
| if (new_inp[i]->shape()[1] % pack_c_size != 0) { | |||
| can_exec_ncwxx = false; | |||
| @@ -2095,36 +2096,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||
| } | |||
| }; | |||
| auto replace_elemwise_opr = [=](OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| bool has_inp_changed = false; | |||
| for (size_t i = 0; i < opr->input().size(); i++) { | |||
| if (new_inp[i]->shape().ndim == 5) { | |||
| has_inp_changed = true; | |||
| break; | |||
| } | |||
| } | |||
| if (has_inp_changed) { | |||
| auto temp_inp = new_inp; | |||
| for (size_t i = 0; i < opr->input().size(); i++) { | |||
| if (new_inp[i]->shape().ndim == 4) { | |||
| auto new_var = RelayoutPlaceholder::make( | |||
| new_inp[i], src_to_nchwxx_mode); | |||
| temp_inp[i] = new_var.node(); | |||
| } else { | |||
| mgb_assert((new_inp[i]->shape().ndim == 5) || | |||
| new_inp[i]->shape().is_scalar()); | |||
| } | |||
| } | |||
| return serialization::copy_opr_shallow(*opr, temp_inp, | |||
| opr->config()); | |||
| } else { | |||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||
| opr->config()); | |||
| } | |||
| }; | |||
| auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| @@ -2146,11 +2117,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||
| replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | |||
| replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | |||
| replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | |||
| replace_func[opr::Concat::typeinfo()] = replace_concat_opr; | |||
| replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | |||
| replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | |||
| replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | |||
| replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | |||
| replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr; | |||
| replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr; | |||
| replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr; | |||
| replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_multi_inp_opr; | |||
| replace_func[opr::PowC::typeinfo()] = replace_multi_inp_opr; | |||
| //! not support yet | |||
| replace_func[opr::ConvolutionBackwardData::typeinfo()] = | |||
| relayout_inp_to_nchw; | |||
| @@ -2164,6 +2135,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||
| replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||
| relayout_inp_to_nchw; | |||
| replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | |||
| replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw; | |||
| } | |||
| std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter( | |||
| @@ -2948,6 +2948,90 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
| } | |||
| TEST(TestGoptInference, ConvertFormatNCHW44MultiInput) { | |||
| HostTensorGenerator<> gen; | |||
| auto cn = CompNode::load("cpu0"); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
| return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); | |||
| }; | |||
| auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
| return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
| .rename(name); | |||
| }; | |||
| auto host_x1 = gen({1, 8, 16, 16}, cn); | |||
| auto host_x2 = gen({1, 1, 16, 16}, cn); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x1); | |||
| opr::Convolution::Param param_conv; | |||
| param_conv.pad_h = param_conv.pad_w = 1; | |||
| auto w1 = mkcvar("w1", {8, 8, 3, 3}), | |||
| conv1 = opr::Convolution::make(x, w1, param_conv); | |||
| auto b = mkvar("b", {1, 1, 16, 16}), | |||
| y = opr::Elemwise::make({conv1 + b}, opr::Elemwise::Param::Mode::RELU); | |||
| SymbolVar y_opt; | |||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||
| options.enable_nchw44(); | |||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
| find_opr<opr::Convolution>(y_opt).param().format); | |||
| graph->compile({{y_opt, {}}}) | |||
| ->to_json() | |||
| ->writeto_fpath(output_file( | |||
| "TestGoptInference.ConvertFormatNCHW44MultiInput.json")); | |||
| HostTensorND host_y_opt, host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y), | |||
| make_callback_copy(y_opt, host_y_opt)}); | |||
| func->execute(); | |||
| //! meybe go to winograd in x86-32, so set error 1e-1 | |||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
| } | |||
| TEST(TestGoptInference, ConvertFormatNCHW44Reshape) { | |||
| HostTensorGenerator<> gen; | |||
| auto cn = CompNode::load("cpu0"); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
| return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
| .rename(name); | |||
| }; | |||
| auto host_x1 = gen({1, 8, 16, 16}, cn); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x1); | |||
| opr::Convolution::Param param_conv; | |||
| param_conv.pad_h = param_conv.pad_w = 1; | |||
| auto w1 = mkcvar("w1", {8, 8, 3, 3}), | |||
| conv1 = opr::Convolution::make(x, w1, param_conv); | |||
| auto y = opr::Reshape::make(conv1, {8, 16 * 16}); | |||
| SymbolVar y_opt; | |||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||
| options.enable_nchw44(); | |||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
| find_opr<opr::Convolution>(y_opt).param().format); | |||
| graph->compile({{y_opt, {}}}) | |||
| ->to_json() | |||
| ->writeto_fpath(output_file( | |||
| "TestGoptInference.ConvertFormatNCHW44Reshape.json")); | |||
| HostTensorND host_y_opt, host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y), | |||
| make_callback_copy(y_opt, host_y_opt)}); | |||
| func->execute(); | |||
| //! meybe go to winograd in x86-32, so set error 1e-1 | |||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
| } | |||
| TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
| HostTensorGenerator<> gen; | |||
| auto cn = CompNode::load("cpu0"); | |||