|
|
|
@@ -15,17 +15,24 @@ |
|
|
|
*/ |
|
|
|
#include "nnacl/infer/infer_register.h" |
|
|
|
|
|
|
|
InferShape g_infer_func[PrimType_MAX]; |
|
|
|
InferShape *g_infer_func; |
|
|
|
|
|
|
|
__attribute__((constructor(101))) void InitInferFuncBuf() { |
|
|
|
g_infer_func = malloc(PrimType_MAX * sizeof(InferShape)); |
|
|
|
if (g_infer_func != NULL) { |
|
|
|
memset(g_infer_func, 0, PrimType_MAX * sizeof(InferShape)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
InferShape GetInferFunc(int prim_type) { |
|
|
|
if (prim_type < PrimType_MAX) { |
|
|
|
if (g_infer_func != NULL && prim_type < PrimType_MAX) { |
|
|
|
return g_infer_func[prim_type]; |
|
|
|
} |
|
|
|
return NULL; |
|
|
|
} |
|
|
|
|
|
|
|
void RegInfer(int prim_type, InferShape func) { |
|
|
|
if (prim_type < PrimType_MAX) { |
|
|
|
if (g_infer_func != NULL && prim_type < PrimType_MAX) { |
|
|
|
g_infer_func[prim_type] = func; |
|
|
|
} |
|
|
|
} |