Browse Source

fix(imperative): add dtype promote support for concat

GitOrigin-RevId: e743a6c995
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
616352b009
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      imperative/src/impl/transformations/dtype_promote.cpp

+ 4
- 3
imperative/src/impl/transformations/dtype_promote.cpp View File

@@ -238,7 +238,7 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }


ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
SmallVector<DType> dtypes = get_value_dtypes(inputs); SmallVector<DType> dtypes = get_value_dtypes(inputs);
mgb::DType target_dtype = get_promoted_dtype(dtypes); mgb::DType target_dtype = get_promoted_dtype(dtypes);


@@ -258,12 +258,13 @@ ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
struct DTypePromoteRuleRegistry { struct DTypePromoteRuleRegistry {
DTypePromoteRuleRegistry() { DTypePromoteRuleRegistry() {
register_dtype_promote_rule<Elemwise>(elemwise_rule); register_dtype_promote_rule<Elemwise>(elemwise_rule);
register_dtype_promote_rule<Concat>(naive_promote_rule);
register_dtype_promote_rule<Reduce>(reduce_rule); register_dtype_promote_rule<Reduce>(reduce_rule);
register_dtype_promote_rule<Convolution>(convolution_rule); register_dtype_promote_rule<Convolution>(convolution_rule);
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule); register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule);
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
register_dtype_promote_rule<Convolution3D>(convolution3d_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(convolution3d_rule);
register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
} }
} register_helper; } register_helper;




Loading…
Cancel
Save