|
|
|
@@ -15,7 +15,7 @@ |
|
|
|
*/ |
|
|
|
#include "nnacl/infer/infer_register.h" |
|
|
|
|
|
|
|
#ifdef SUPPORT_MSVC |
|
|
|
#ifdef _MSC_VER |
|
|
|
#include "nnacl/infer/adam_infer.h" |
|
|
|
#include "nnacl/infer/add_sub_grad_infer.h" |
|
|
|
#include "nnacl/infer/addn_infer.h" |
|
|
|
@@ -77,7 +77,6 @@ |
|
|
|
#include "nnacl/infer/matmul_infer.h" |
|
|
|
#include "nnacl/infer/max_min_grad_infer.h" |
|
|
|
#include "nnacl/infer/mean_infer.h" |
|
|
|
#include "nnacl/infer/control/merge_infer.h" |
|
|
|
#include "nnacl/infer/mfcc_infer.h" |
|
|
|
#include "nnacl/infer/non_max_suppression_infer.h" |
|
|
|
#include "nnacl/infer/one_hot_infer.h" |
|
|
|
@@ -117,7 +116,6 @@ |
|
|
|
#include "nnacl/infer/stack_infer.h" |
|
|
|
#include "nnacl/infer/strided_slice_grad_infer.h" |
|
|
|
#include "nnacl/infer/strided_slice_infer.h" |
|
|
|
#include "nnacl/infer/control/switch_infer.h" |
|
|
|
#include "nnacl/infer/control/tensorlist_fromtensor_infer.h" |
|
|
|
#include "nnacl/infer/control/tensorlist_getitem_infer.h" |
|
|
|
#include "nnacl/infer/control/tensorlist_reserve_infer.h" |
|
|
|
@@ -133,6 +131,15 @@ |
|
|
|
#include "nnacl/infer/unstack_infer.h" |
|
|
|
#include "nnacl/infer/where_infer.h" |
|
|
|
#include "nnacl/infer/while_infer.h" |
|
|
|
#include "nnacl/infer/split_with_over_lap_infer.h" |
|
|
|
#include "nnacl/infer/ragged_range_infer.h" |
|
|
|
#include "nnacl/infer/glu_infer.h" |
|
|
|
#include "nnacl/infer/control/tensor_array_read_infer.h" |
|
|
|
#include "nnacl/infer/control/tensor_array_infer.h" |
|
|
|
#include "nnacl/infer/control/tensor_array_write_infer.h" |
|
|
|
#include "nnacl/infer/affine_infer.h" |
|
|
|
#include "nnacl/infer/attention_infer.h" |
|
|
|
#include "nnacl/infer/scatter_nd_update_infer.h" |
|
|
|
|
|
|
|
InferShape g_infer_func[PrimType_MAX * sizeof(InferShape)] = {0}; |
|
|
|
void RegAllInferFunc1() { |
|
|
|
@@ -230,7 +237,7 @@ void RegAllInferFunc1() { |
|
|
|
g_infer_func[PrimType_MaximumGrad] = MaxMinGradInferShape; |
|
|
|
g_infer_func[PrimType_MaxPoolFusion] = PoolingInferShape; |
|
|
|
g_infer_func[PrimType_MaxPoolGrad] = PoolingGradInferShape; |
|
|
|
g_infer_func[PrimType_Merge] = MergeInferShape; |
|
|
|
g_infer_func[PrimType_Merge] = NULL; |
|
|
|
g_infer_func[PrimType_Mfcc] = MfccInferShape; |
|
|
|
g_infer_func[PrimType_Minimum] = ArithmeticInferShape; |
|
|
|
g_infer_func[PrimType_MinimumGrad] = MaxMinGradInferShape; |
|
|
|
@@ -293,7 +300,7 @@ void RegAllInferFunc2() { |
|
|
|
g_infer_func[PrimType_StridedSlice] = StridedSliceInferShape; |
|
|
|
g_infer_func[PrimType_SubFusion] = ArithmeticInferShape; |
|
|
|
g_infer_func[PrimType_SubGrad] = AddSubGradInferShape; |
|
|
|
g_infer_func[PrimType_Switch] = SwitchInferShape; |
|
|
|
g_infer_func[PrimType_Switch] = NULL; |
|
|
|
g_infer_func[PrimType_TensorListFromTensor] = TensorListFromTensorInferShape; |
|
|
|
g_infer_func[PrimType_TensorListGetItem] = TensorListGetItemInferShape; |
|
|
|
g_infer_func[PrimType_TensorListReserve] = TensorListReserveInferShape; |
|
|
|
@@ -334,16 +341,32 @@ void RegAllInferFunc2() { |
|
|
|
g_infer_func[PrimType_CumSum] = CumsumInferShape; |
|
|
|
} |
|
|
|
|
|
|
|
typedef void RegFunc(); |
|
|
|
#pragma data_seg(".CRT$XIU") |
|
|
|
static RegFunc *before[] = {RegAllInferFunc1, RegAllInferFunc2}; |
|
|
|
#pragma data_seg() |
|
|
|
void RegAllInferFunc3() { |
|
|
|
g_infer_func[PrimType_SplitWithOverlap] = SplitWithOverlapInferShape; |
|
|
|
g_infer_func[PrimType_GenOP] = NULL; |
|
|
|
g_infer_func[PrimType_RaggedRange] = RaggedRangeInferShape; |
|
|
|
g_infer_func[PrimType_GLU] = GluInferShape; |
|
|
|
g_infer_func[PrimType_TensorArray] = TensorArrayInferShape; |
|
|
|
g_infer_func[PrimType_TensorArrayRead] = TensorArrayReadInferShape; |
|
|
|
g_infer_func[PrimType_TensorArrayWrite] = TensorArrayWriteInferShape; |
|
|
|
g_infer_func[PrimType_Affine] = AffineInferShape; |
|
|
|
g_infer_func[PrimType_Attention] = AttentionInferShape; |
|
|
|
g_infer_func[PrimType_LSTMGrad] = NULL; |
|
|
|
g_infer_func[PrimType_ScatterNdUpdate] = ScatterNdUpdateInferShape; |
|
|
|
} |
|
|
|
|
|
|
|
#else |
|
|
|
__attribute__((init_priority(101))) InferShape g_infer_func[PrimType_MAX * sizeof(InferShape)] = {0}; |
|
|
|
#endif // SUPPORT_MSVC |
|
|
|
#endif // _MSC_VER |
|
|
|
|
|
|
|
InferShape GetInferFunc(int prim_type) { |
|
|
|
#ifdef _MSC_VER |
|
|
|
if (g_infer_func[PrimType_Abs] == NULL) { |
|
|
|
RegAllInferFunc1(); |
|
|
|
RegAllInferFunc2(); |
|
|
|
RegAllInferFunc3(); |
|
|
|
} |
|
|
|
#endif |
|
|
|
if (prim_type < PrimType_MAX) { |
|
|
|
return g_infer_func[prim_type]; |
|
|
|
} |
|
|
|
|