|
|
@@ -238,7 +238,7 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
return imperative::apply(op, inputs); |
|
|
return imperative::apply(op, inputs); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
|
|
|
|
|
|
ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
SmallVector<DType> dtypes = get_value_dtypes(inputs); |
|
|
mgb::DType target_dtype = get_promoted_dtype(dtypes); |
|
|
mgb::DType target_dtype = get_promoted_dtype(dtypes); |
|
|
|
|
|
|
|
|
@@ -258,12 +258,13 @@ ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) { |
|
|
struct DTypePromoteRuleRegistry { |
|
|
struct DTypePromoteRuleRegistry { |
|
|
DTypePromoteRuleRegistry() { |
|
|
DTypePromoteRuleRegistry() { |
|
|
register_dtype_promote_rule<Elemwise>(elemwise_rule); |
|
|
register_dtype_promote_rule<Elemwise>(elemwise_rule); |
|
|
|
|
|
register_dtype_promote_rule<Concat>(naive_promote_rule); |
|
|
register_dtype_promote_rule<Reduce>(reduce_rule); |
|
|
register_dtype_promote_rule<Reduce>(reduce_rule); |
|
|
register_dtype_promote_rule<Convolution>(convolution_rule); |
|
|
register_dtype_promote_rule<Convolution>(convolution_rule); |
|
|
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule); |
|
|
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule); |
|
|
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); |
|
|
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); |
|
|
register_dtype_promote_rule<Convolution3D>(convolution3d_rule); |
|
|
|
|
|
register_dtype_promote_rule<Convolution3DBackwardData>(convolution3d_rule); |
|
|
|
|
|
|
|
|
register_dtype_promote_rule<Convolution3D>(naive_promote_rule); |
|
|
|
|
|
register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule); |
|
|
} |
|
|
} |
|
|
} register_helper; |
|
|
} register_helper; |
|
|
|
|
|
|
|
|
|