| @@ -313,7 +313,6 @@ protected: | |||||
| size_t workspace_in_bytes); | size_t workspace_in_bytes); | ||||
| }; | }; | ||||
| using Cumsum = CumsumForward; | using Cumsum = CumsumForward; | ||||
| // mxx can be max or min | // mxx can be max or min | ||||
| class ArgmxxBase : public OperatorBase { | class ArgmxxBase : public OperatorBase { | ||||
| DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase); | DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase); | ||||
| @@ -48,6 +48,19 @@ MODES = { | |||||
| "H_SWISH", | "H_SWISH", | ||||
| "SILU", | "SILU", | ||||
| "GELU", | "GELU", | ||||
| "SINH", | |||||
| "COSH", | |||||
| "ASINH", | |||||
| "ACOSH", | |||||
| "ATANH", | |||||
| "TAN", | |||||
| "SOFTPLUS", | |||||
| "RELU6", | |||||
| "HSIGMOID", | |||||
| "LOGSIGMOID", | |||||
| "SQRT", | |||||
| "SQUARE", | |||||
| "SIGN", | |||||
| ], | ], | ||||
| 2: [ | 2: [ | ||||
| "ABS_GRAD", | "ABS_GRAD", | ||||
| @@ -76,8 +89,15 @@ MODES = { | |||||
| "FUSE_ADD_H_SWISH", | "FUSE_ADD_H_SWISH", | ||||
| "SILU_GRAD", | "SILU_GRAD", | ||||
| "GELU_GRAD", | "GELU_GRAD", | ||||
| "PRELU", | |||||
| "ASINH_GRAD", | |||||
| "ACOSH_GRAD", | |||||
| "ATANH_GRAD", | |||||
| "SOFTPLUS_GRAD", | |||||
| "RELU6_GRAD", | |||||
| "HSIGMOID_GRAD", | |||||
| ], | ], | ||||
| 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
| 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3", "CLIP", "PRELU_GRAD"], | |||||
| } | } | ||||
| QINT4_MODES = { | QINT4_MODES = { | ||||
| @@ -107,8 +127,9 @@ QINT4_MODES = { | |||||
| "FUSE_ADD_TANH", | "FUSE_ADD_TANH", | ||||
| "FUSE_ADD_SIGMOID", | "FUSE_ADD_SIGMOID", | ||||
| "FUSE_ADD_H_SWISH", | "FUSE_ADD_H_SWISH", | ||||
| "PRELU", | |||||
| ], | ], | ||||
| 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
| 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3", "CLIP"], | |||||
| } | } | ||||
| QINT32_MODES = { | QINT32_MODES = { | ||||
| @@ -12,7 +12,7 @@ DTYPES = { | |||||
| } | } | ||||
| MODES = { | MODES = { | ||||
| (1, "INT"): ["RELU", "ABS", "NEGATE"], | |||||
| (1, "INT"): ["RELU", "ABS", "NEGATE", "RELU6", "SQUARE", "SIGN"], | |||||
| (2, "INT"): [ | (2, "INT"): [ | ||||
| "ABS_GRAD", | "ABS_GRAD", | ||||
| "ADD", | "ADD", | ||||
| @@ -32,8 +32,9 @@ MODES = { | |||||
| "SHL", | "SHL", | ||||
| "SHR", | "SHR", | ||||
| "RMULH", | "RMULH", | ||||
| "PRELU", | |||||
| ], | ], | ||||
| (3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"], | |||||
| (3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV", "CLIP"], | |||||
| (1, "FLOAT"): [ | (1, "FLOAT"): [ | ||||
| "RELU", | "RELU", | ||||
| "ABS", | "ABS", | ||||
| @@ -59,6 +60,19 @@ MODES = { | |||||
| "H_SWISH", | "H_SWISH", | ||||
| "SILU", | "SILU", | ||||
| "GELU", | "GELU", | ||||
| "SINH", | |||||
| "COSH", | |||||
| "ASINH", | |||||
| "ACOSH", | |||||
| "ATANH", | |||||
| "TAN", | |||||
| "SOFTPLUS", | |||||
| "RELU6", | |||||
| "HSIGMOID", | |||||
| "LOGSIGMOID", | |||||
| "SQRT", | |||||
| "SQUARE", | |||||
| "SIGN", | |||||
| ], | ], | ||||
| (2, "FLOAT"): [ | (2, "FLOAT"): [ | ||||
| "ABS_GRAD", | "ABS_GRAD", | ||||
| @@ -87,8 +101,21 @@ MODES = { | |||||
| "FUSE_ADD_H_SWISH", | "FUSE_ADD_H_SWISH", | ||||
| "SILU_GRAD", | "SILU_GRAD", | ||||
| "GELU_GRAD", | "GELU_GRAD", | ||||
| "PRELU", | |||||
| "ASINH_GRAD", | |||||
| "ACOSH_GRAD", | |||||
| "ATANH_GRAD", | |||||
| "SOFTPLUS_GRAD", | |||||
| "RELU6_GRAD", | |||||
| "HSIGMOID_GRAD", | |||||
| ], | |||||
| (3, "FLOAT"): [ | |||||
| "COND_LEQ_MOV", | |||||
| "COND_LT_MOV", | |||||
| "FUSE_MUL_ADD3", | |||||
| "CLIP", | |||||
| "PRELU_GRAD", | |||||
| ], | ], | ||||
| (3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
| (1, "BOOL"): ["NOT"], | (1, "BOOL"): ["NOT"], | ||||
| (2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"], | (2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"], | ||||
| (3, "BOOL"): [], | (3, "BOOL"): [], | ||||
| @@ -424,6 +424,28 @@ pdef('Elemwise').add_enum( | |||||
| Doc('NEQ = 61', 'binary: x != y'), | Doc('NEQ = 61', 'binary: x != y'), | ||||
| Doc('ISNAN = 62', 'unary: isnan(x)'), | Doc('ISNAN = 62', 'unary: isnan(x)'), | ||||
| Doc('ISINF = 63', 'unary: isinf(x)'), | Doc('ISINF = 63', 'unary: isinf(x)'), | ||||
| Doc('SINH = 64', 'unary: sinh(x)'), | |||||
| Doc('COSH = 65', 'unary: cosh(x)'), | |||||
| Doc('ASINH = 66', 'unary: asinh(x)'), | |||||
| Doc('ACOSH = 67', 'unary: acosh(x)'), | |||||
| Doc('ATANH = 68', 'unary: atanh(x)'), | |||||
| Doc('TAN = 69', 'unary: tan(x)'), | |||||
| Doc('ASINH_GRAD = 70', 'binary: y / sqrt(x^2 + 1)'), | |||||
| Doc('ACOSH_GRAD = 71', 'binary: y / sqrt(x^2 - 1) (x > 1)'), | |||||
| Doc('ATANH_GRAD = 72', 'binary: y / (1 - x^2) (|x| < 1)'), | |||||
| Doc('PRELU = 73', 'binary: x > 0 ? x : x * y'), | |||||
| Doc('CLIP = 74', 'ternary: x <= y ? y : (x <= z ? x : z)'), | |||||
| Doc('PRELU_GRAD = 75', 'ternary: x > 0 ? y : y * z'), | |||||
| Doc('SOFTPLUS = 76', 'unary: log(1 + e^x)'), | |||||
| Doc('SOFTPLUS_GRAD = 77', 'binary: y * e^x / (1 + e^x)'), | |||||
| Doc('RELU6 = 78', 'unary: min(max(0, x), 6)'), | |||||
| Doc('RELU6_GRAD = 79', 'binary: x < 0 ? 0 : (x > 6 ? 0 : y)'), | |||||
| Doc('HSIGMOID = 80', 'unary: relu6(x + 3) / 6'), | |||||
| Doc('HSIGMOID_GRAD = 81', 'binary: x < -3 ? 0 : (x > 3 ? 0 : y / 6)'), | |||||
| Doc('LOGSIGMOID = 82', 'unary: -log(1 + e^(-x))'), | |||||
| Doc('SQRT = 83', 'unary: x^(1/2)'), | |||||
| Doc('SQUARE = 84', 'unary: x^2'), | |||||
| Doc('SIGN = 85', 'unary: sgn(x)'), | |||||
| ) | ) | ||||
| pdef('ElemwiseMultiType').add_enum( | pdef('ElemwiseMultiType').add_enum( | ||||
| @@ -25,12 +25,28 @@ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ | ||||
| @@ -66,7 +82,14 @@ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | ||||
| @@ -86,15 +109,19 @@ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb) | #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb) | ||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| @@ -154,11 +154,18 @@ struct ElemwiseKern; | |||||
| // int and float | // int and float | ||||
| DEF_KERN_ALL(NEGATE, -x); | DEF_KERN_ALL(NEGATE, -x); | ||||
| DEF_KERN_ALL(SQUARE, x* x); | |||||
| #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) | #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) | ||||
| DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x); | DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x); | ||||
| DEF_KERN_INT(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6))); | |||||
| DEF_KERN_INT(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0))); | |||||
| DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x); | DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x); | ||||
| DEF_KERN_FLOAT(RELU6, x <= 6.f ? ctype(0) : (x <= 6.f ? x : ctype(6))); | |||||
| DEF_KERN_FLOAT(SIGN, x < 0.f ? -1.f : (x > 0.f ? 1.f : 0.f)); | |||||
| #else | #else | ||||
| DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x); | DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x); | ||||
| DEF_KERN_ALL(RELU6, x <= ctype(0) ? ctype(0) : (x <= ctype(6) ? x : ctype(6))); | |||||
| DEF_KERN_ALL(SIGN, x < ctype(0) ? ctype(-1) : (x > ctype(0) ? ctype(1) : ctype(0))); | |||||
| #endif | #endif | ||||
| DEF_KERN_INT(ABS, abs(int(x))); | DEF_KERN_INT(ABS, abs(int(x))); | ||||
| // DEF_KERN_INT(ABS, x > ctype(0) ? x : -x); | // DEF_KERN_INT(ABS, x > ctype(0) ? x : -x); | ||||
| @@ -186,6 +193,18 @@ DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); | |||||
| DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); | DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); | ||||
| DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); | DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); | ||||
| DEF_KERN_FLOAT(GELU, x* normcdf(x)); | DEF_KERN_FLOAT(GELU, x* normcdf(x)); | ||||
| DEF_KERN_FLOAT(SINH, sinhf(x)); | |||||
| DEF_KERN_FLOAT(COSH, coshf(x)); | |||||
| DEF_KERN_FLOAT(ASINH, asinhf(x)); | |||||
| DEF_KERN_FLOAT(ACOSH, acoshf(x)); | |||||
| DEF_KERN_FLOAT(ATANH, atanhf(x)); | |||||
| DEF_KERN_FLOAT(TAN, tanf(x)); | |||||
| DEF_KERN_FLOAT(SOFTPLUS, log1pf(expf(-fabsf(x))) + (x <= ctype(0) ? ctype(0) : x)); | |||||
| DEF_KERN_FLOAT( | |||||
| HSIGMOID, | |||||
| x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(1) : ((x + 3.f) / 6.f))); | |||||
| DEF_KERN_FLOAT(SQRT, sqrtf(x)); | |||||
| DEF_KERN_FLOAT(LOGSIGMOID, -log1pf(expf(-fabsf(x))) + (x >= ctype(0) ? ctype(0) : x)); | |||||
| // int only | // int only | ||||
| DEF_KERN(dt_bool, NOT, x ^ 1); | DEF_KERN(dt_bool, NOT, x ^ 1); | ||||
| @@ -240,6 +259,12 @@ DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y)); | |||||
| #else | #else | ||||
| DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y)); | DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y)); | ||||
| #endif | #endif | ||||
| #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) | |||||
| DEF_KERN_INT(PRELU, x > ctype(0) ? x : (x * y)); | |||||
| DEF_KERN_FLOAT(PRELU, x > 0.f ? x : (x * y)); | |||||
| #else | |||||
| DEF_KERN_ALL(PRELU, x > ctype(0) ? x : (x * y)); | |||||
| #endif | |||||
| // float only | // float only | ||||
| DEF_KERN_FLOAT(TRUE_DIV, x / y); | DEF_KERN_FLOAT(TRUE_DIV, x / y); | ||||
| @@ -259,6 +284,14 @@ DEF_KERN_FLOAT( | |||||
| DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); | DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); | ||||
| DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); | DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); | ||||
| DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); | DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); | ||||
| DEF_KERN_FLOAT(ASINH_GRAD, y / sqrt(x * x + 1.f)); | |||||
| DEF_KERN_FLOAT(ACOSH_GRAD, y / sqrt(x * x - 1.f)); | |||||
| DEF_KERN_FLOAT(ATANH_GRAD, y / (1.f - x * x)); | |||||
| DEF_KERN_FLOAT(SOFTPLUS_GRAD, y* expf(x) / (1.f + expf(x))); | |||||
| DEF_KERN_FLOAT(RELU6_GRAD, x <= ctype(0) ? ctype(0) : (x >= ctype(6) ? ctype(0) : y)); | |||||
| DEF_KERN_FLOAT( | |||||
| HSIGMOID_GRAD, | |||||
| x <= ctype(-3) ? ctype(0) : (x >= ctype(3) ? ctype(0) : (y / 6.f))); | |||||
| #undef KERN_SIG | #undef KERN_SIG | ||||
| /* ================== ternary kernels ================== */ | /* ================== ternary kernels ================== */ | ||||
| @@ -268,6 +301,8 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); | |||||
| DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); | DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); | ||||
| DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0)); | DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0)); | ||||
| DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); | DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); | ||||
| DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z)); | |||||
| DEF_KERN_FLOAT(PRELU_GRAD, x >= 0.f ? y : (y * z)); | |||||
| #undef KERN_SIG | #undef KERN_SIG | ||||
| @@ -62,6 +62,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); | ||||
| cb(NEQ); | |||||
| cb(ISNAN); | |||||
| cb(ISINF); | |||||
| #undef cb | #undef cb | ||||
| #define cb(_m) \ | #define cb(_m) \ | ||||
| @@ -84,11 +87,14 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); | ||||
| cb(ISNAN); | |||||
| cb(ISINF); | |||||
| #undef _a | #undef _a | ||||
| #define _a 2 | #define _a 2 | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb); | ||||
| cb(NEQ); | |||||
| #undef _a | #undef _a | ||||
| #define _a 3 | #define _a 3 | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); | ||||
| @@ -223,6 +229,28 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| CB_MODE(Mode::GELU); | CB_MODE(Mode::GELU); | ||||
| CB_MODE(Mode::GELU_GRAD); | CB_MODE(Mode::GELU_GRAD); | ||||
| CB_MODE(Mode::COND_LT_MOV); | CB_MODE(Mode::COND_LT_MOV); | ||||
| CB_MODE(Mode::SINH); | |||||
| CB_MODE(Mode::COSH); | |||||
| CB_MODE(Mode::ASINH); | |||||
| CB_MODE(Mode::ACOSH); | |||||
| CB_MODE(Mode::ATANH); | |||||
| CB_MODE(Mode::TAN); | |||||
| CB_MODE(Mode::ASINH_GRAD); | |||||
| CB_MODE(Mode::ACOSH_GRAD); | |||||
| CB_MODE(Mode::ATANH_GRAD); | |||||
| CB_MODE(Mode::PRELU); | |||||
| CB_MODE(Mode::PRELU_GRAD); | |||||
| CB_MODE(Mode::CLIP); | |||||
| CB_MODE(Mode::SOFTPLUS); | |||||
| CB_MODE(Mode::SOFTPLUS_GRAD); | |||||
| CB_MODE(Mode::RELU6); | |||||
| CB_MODE(Mode::RELU6_GRAD); | |||||
| CB_MODE(Mode::HSIGMOID); | |||||
| CB_MODE(Mode::HSIGMOID_GRAD); | |||||
| CB_MODE(Mode::LOGSIGMOID); | |||||
| CB_MODE(Mode::SQRT); | |||||
| CB_MODE(Mode::SQUARE); | |||||
| CB_MODE(Mode::SIGN); | |||||
| default: | default: | ||||
| megdnn_assert( | megdnn_assert( | ||||
| 0, | 0, | ||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int16 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_uint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COSH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOGSIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_int16 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_int32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_int8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(PRELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_uint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int16 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU6, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_uint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int16 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_uint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SINH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQRT, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int16 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_int8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SQUARE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_uint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TAN, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint8 | |||||
| #define KERN_IMPL_DTYPE dt_qint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOSH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint8 | |||||
| #define KERN_IMPL_DTYPE dt_qint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint8 | |||||
| #define KERN_IMPL_DTYPE dt_qint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASINH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint8 | |||||
| #define KERN_IMPL_DTYPE dt_qint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint8 | |||||
| #define KERN_IMPL_DTYPE dt_qint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint8 | |||||
| #define KERN_IMPL_DTYPE dt_qint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||