GitOrigin-RevId: d828512e44
tags/v0.5.0
| @@ -26,6 +26,7 @@ | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/opr/nn_int.h" | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| #include "megdnn/tensor_format.h" | |||
| @@ -741,6 +742,19 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| return opr; | |||
| }; | |||
| auto replace_lsp_opr = [](OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->same_type<opr::Linspace>()); | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| auto& lsp_opr = opr->cast_final_safe<opr::Linspace>(); | |||
| if (lsp_opr.output(0)->dtype() != dtype::Float16()) { | |||
| auto cvt_var = | |||
| opr::TypeCvt::make(lsp_opr.output(0), dtype::Float16(), {}); | |||
| return cvt_var.node()->owner_opr(); | |||
| } | |||
| return opr; | |||
| }; | |||
| auto replace_conv_opr = [use_f32_comp](OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| @@ -778,6 +792,29 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| return new_matmul_opr.node()->owner_opr(); | |||
| }; | |||
| auto replace_batched_matmul_opr = [use_f32_comp]( | |||
| OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| auto& matmul_opr = opr->cast_final_safe<opr::BatchedMatrixMul>(); | |||
| auto new_param = matmul_opr.param(); | |||
| if (use_f32_comp) { | |||
| new_param.compute_mode = | |||
| megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||
| } | |||
| mgb_assert(new_inp[0]->dtype() == dtype::Float16(), | |||
| "inp %s:%s, owner_opr:%s", new_inp[0]->dtype().name(), | |||
| new_inp[0]->name().c_str(), | |||
| new_inp[0]->owner_opr()->name().c_str()); | |||
| mgb_assert(new_inp[1]->dtype() == dtype::Float16(), | |||
| "inp %s:%s, owner_opr:%s", new_inp[1]->dtype().name(), | |||
| new_inp[1]->name().c_str(), | |||
| new_inp[1]->owner_opr()->name().c_str()); | |||
| auto new_matmul_opr = opr::BatchedMatrixMul::make( | |||
| new_inp[0], new_inp[1], new_param, matmul_opr.config()); | |||
| return new_matmul_opr.node()->owner_opr(); | |||
| }; | |||
| auto replace_reduce_opr = [use_f32_comp](OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| auto& reduce_opr = opr->cast_final_safe<opr::Reduce>(); | |||
| @@ -871,6 +908,7 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ | |||
| VarReplaceCheckFlag::CHECK_DTYPE); | |||
| auto&& replace_func = ret->m_opr_replace_func; | |||
| replace_func[opr::Linspace::typeinfo()] = replace_lsp_opr; | |||
| replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr; | |||
| replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr; | |||
| replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | |||
| @@ -880,6 +918,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( | |||
| replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr; | |||
| replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr; | |||
| replace_func[opr::Remap::typeinfo()] = replace_remap_opr; | |||
| replace_func[opr::BatchedMatrixMul::typeinfo()] = | |||
| replace_batched_matmul_opr; | |||
| return ret; | |||
| #endif | |||
| } | |||
| @@ -27,6 +27,8 @@ | |||
| #include "megbrain/opr/nn_int.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| #include "megbrain/opr/blas.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "./helper.h" | |||
| @@ -892,6 +894,67 @@ TEST(TestGoptInference, Float32TOFloat16EndpointElemwise) { | |||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||
| } | |||
| TEST(TestGoptInference, Float32TOFloat16Linspace) { | |||
| CompNode cn = CompNode::load("cpu0"); | |||
| HostTensorGenerator<> gen(0, 1, 0); | |||
| auto host_x = gen({3, 1}, cn); | |||
| auto graph = ComputingGraph::make(); | |||
| auto make_f32_to_f16_graph = [&]() { | |||
| graph->options().graph_opt_level = 0; | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
| auto xshp = opr::GetVarShape::make(x); | |||
| auto cv = [&x](int v) { return x.make_scalar(v); }; | |||
| auto sub = [&xshp, &cv](int idx) { | |||
| return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | |||
| }; | |||
| auto lin = opr::Linspace::make(cv(0), sub(0) - 1, sub(0), {}, {}); | |||
| auto shp = opr::Concat::make({sub(1), sub(0)}, 0); | |||
| auto y = opr::Reshape::make(lin, shp); | |||
| auto mm = opr::MatrixMul::make(x, y); | |||
| SymbolVar mm_opt; | |||
| unpack_vector(gopt::optimize_for_inference( | |||
| {mm}, gopt::OptimizeForInferenceOptions{} | |||
| .enable_f16_io_comp()), | |||
| mm_opt); | |||
| return mm_opt; | |||
| }; | |||
| auto make_f16_graph = [&]() { | |||
| auto x = opr::TypeCvt::make(opr::Host2DeviceCopy::make(*graph, host_x), | |||
| dtype::Float16()); | |||
| auto xshp = opr::GetVarShape::make(x); | |||
| auto cv = [&x](int v) { return x.make_scalar(v); }; | |||
| auto sub = [&xshp, &cv](int idx) { | |||
| return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | |||
| }; | |||
| auto lin = opr::Linspace::make(cv(0), sub(0) - 1, sub(0), {}, {}); | |||
| lin = opr::TypeCvt::make(lin, dtype::Float16()); | |||
| auto shp = opr::Concat::make({sub(1), sub(0)}, 0); | |||
| auto y = opr::Reshape::make(lin, shp); | |||
| auto mm = opr::MatrixMul::make(x, y); | |||
| mm = opr::TypeCvt::make(mm, dtype::Float32{}); | |||
| return mm; | |||
| }; | |||
| auto y_opt = make_f32_to_f16_graph(); | |||
| auto y = make_f16_graph(); | |||
| ASSERT_EQ(y_opt.dtype(), dtype::Float32{}); | |||
| ASSERT_EQ(y.dtype(), dtype::Float32{}); | |||
| HostTensorND host_y_opt, host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y), | |||
| make_callback_copy(y_opt, host_y_opt)}); | |||
| func->execute(); | |||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||
| } | |||
| TEST(TestGoptInference, ConvertFormatNHWCD4) { | |||
| // hwcd4 is only supported in naive handle | |||
| NaiveMegDNNHandleScope naive_megdnn_handle; | |||