Browse Source

fix(imperative): add dtype promote support for concat

GitOrigin-RevId: e743a6c995
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
616352b009
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      imperative/src/impl/transformations/dtype_promote.cpp

+ 4
- 3
imperative/src/impl/transformations/dtype_promote.cpp View File

@@ -238,7 +238,7 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, inputs);
}

ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
SmallVector<DType> dtypes = get_value_dtypes(inputs);
mgb::DType target_dtype = get_promoted_dtype(dtypes);

@@ -258,12 +258,13 @@ ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
struct DTypePromoteRuleRegistry {
DTypePromoteRuleRegistry() {
register_dtype_promote_rule<Elemwise>(elemwise_rule);
register_dtype_promote_rule<Concat>(naive_promote_rule);
register_dtype_promote_rule<Reduce>(reduce_rule);
register_dtype_promote_rule<Convolution>(convolution_rule);
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule);
register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
register_dtype_promote_rule<Convolution3D>(convolution3d_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(convolution3d_rule);
register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
}
} register_helper;



Loading…
Cancel
Save