diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index 4cdcb0b1..d39b6565 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp +++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -238,7 +238,7 @@ ValueRefList batch_norm_rule(const OpDef& op, Span inputs) { return imperative::apply(op, inputs); } -ValueRefList convolution3d_rule(const OpDef& op, Span inputs) { +ValueRefList naive_promote_rule(const OpDef& op, Span inputs) { SmallVector dtypes = get_value_dtypes(inputs); mgb::DType target_dtype = get_promoted_dtype(dtypes); @@ -258,12 +258,13 @@ ValueRefList convolution3d_rule(const OpDef& op, Span inputs) { struct DTypePromoteRuleRegistry { DTypePromoteRuleRegistry() { register_dtype_promote_rule(elemwise_rule); + register_dtype_promote_rule(naive_promote_rule); register_dtype_promote_rule(reduce_rule); register_dtype_promote_rule(convolution_rule); register_dtype_promote_rule(convolution_backward_rule); register_dtype_promote_rule(batch_norm_rule); - register_dtype_promote_rule(convolution3d_rule); - register_dtype_promote_rule(convolution3d_rule); + register_dtype_promote_rule(naive_promote_rule); + register_dtype_promote_rule(naive_promote_rule); } } register_helper;