|
|
|
@@ -18,20 +18,20 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to( |
|
|
|
const FormattedTensorValue& tensor, const FT& target, |
|
|
|
const std::string& scope) const { |
|
|
|
std::vector<int32_t> pattern; |
|
|
|
if (tensor.format() == FT::NHWC && target == FT::NCHW) { |
|
|
|
Format format = tensor.format(); |
|
|
|
if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) { |
|
|
|
// FIXME(czh): temporary fast path for group conv 5D weight. |
|
|
|
if (tensor.value().shape().cast<ShapeValue>().ndim == 5) { |
|
|
|
pattern = {0, 1, 4, 2, 3}; |
|
|
|
} else { |
|
|
|
pattern = {0, 3, 1, 2}; |
|
|
|
} |
|
|
|
} else if (tensor.format() == FT::NCHW && target == FT::NHWC) { |
|
|
|
} else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) { |
|
|
|
pattern = {0, 2, 3, 1}; |
|
|
|
} else { |
|
|
|
mgb_throw( |
|
|
|
MegBrainError, "Unsupport format conversion from %s to %s", |
|
|
|
tensor.format().to_string().c_str(), |
|
|
|
Format(target).to_string().c_str()); |
|
|
|
format.to_string().c_str(), Format(target).to_string().c_str()); |
|
|
|
} |
|
|
|
auto output = |
|
|
|
imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0]; |
|
|
|
@@ -84,7 +84,7 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { |
|
|
|
return out; |
|
|
|
} else { |
|
|
|
mgb_throw( |
|
|
|
MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).", |
|
|
|
MegBrainError, "Unsupported shape ndim %lu in GetAttr(Shape).", |
|
|
|
shape.ndim); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -189,7 +189,7 @@ ValueRefList reshape_rule( |
|
|
|
return t.wrap_outputs(outputs, FT::NHWC); |
|
|
|
} else { |
|
|
|
// will not maintain src's format |
|
|
|
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value(); |
|
|
|
auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value(); |
|
|
|
auto outputs = imperative::apply(op, {nchw_src}); |
|
|
|
return t.wrap_outputs(outputs); |
|
|
|
} |
|
|
|
@@ -204,7 +204,7 @@ ValueRefList reshape_rule( |
|
|
|
return t.wrap_outputs(outputs, FT::NHWC); |
|
|
|
} else { |
|
|
|
// will not maintain src's format |
|
|
|
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value(); |
|
|
|
auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value(); |
|
|
|
auto outputs = imperative::apply( |
|
|
|
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])}); |
|
|
|
return t.wrap_outputs(outputs); |
|
|
|
@@ -229,7 +229,7 @@ ValueRefList broadcast_rule( |
|
|
|
return t.wrap_outputs(outputs, FT::NHWC); |
|
|
|
} else { |
|
|
|
// will not maintain src's format |
|
|
|
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value(); |
|
|
|
auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value(); |
|
|
|
auto outputs = imperative::apply(op, {nchw_src}); |
|
|
|
return t.wrap_outputs(outputs); |
|
|
|
} |
|
|
|
@@ -244,7 +244,7 @@ ValueRefList broadcast_rule( |
|
|
|
return t.wrap_outputs(outputs, FT::NHWC); |
|
|
|
} else { |
|
|
|
// will not maintain src's format |
|
|
|
auto nchw_src = t.to(src, FT::NCHW, op.scope())->value(); |
|
|
|
auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value(); |
|
|
|
auto outputs = imperative::apply( |
|
|
|
op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])}); |
|
|
|
return t.wrap_outputs(outputs); |
|
|
|
@@ -323,7 +323,7 @@ ValueRefList setsubtensor_rule( |
|
|
|
auto nhwc_inputs = ValueRefList(inputs.size()); |
|
|
|
if (format == FT::DEFAULT || format == FT::NCHW) { |
|
|
|
// value for setsubtensor should transpose to match shape. |
|
|
|
auto nhwc_value = t.to(*(t.as(value, FT::NCHW)), FT::NHWC); |
|
|
|
auto nhwc_value = t.to(value, FT::NHWC); |
|
|
|
// make new inputs for setsubtensor |
|
|
|
nhwc_inputs[0] = src.value(); |
|
|
|
nhwc_inputs[1] = nhwc_value->value(); |
|
|
|
@@ -355,14 +355,15 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& |
|
|
|
return format; |
|
|
|
} |
|
|
|
|
|
|
|
inline ValueRefList unify_nhwc_inputs( |
|
|
|
Span<ValueRef>& inputs, std::string scope, const FormatTransformation& t) { |
|
|
|
inline ValueRefList unify_inputs_format( |
|
|
|
const Span<ValueRef>& inputs, const FT& dst_fmt, const std::string& scope, |
|
|
|
const FormatTransformation& t) { |
|
|
|
ValueRefList unified_inputs(inputs.size()); |
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
auto&& inp = inputs[i].cast(t.value_type()); |
|
|
|
if (inp.format() != FT::NHWC && |
|
|
|
if (inp.format() != dst_fmt && |
|
|
|
inp.value().shape().cast<ShapeValue>().ndim == 4) { |
|
|
|
unified_inputs[i] = t.to(*t.as(inp, FT::NCHW), FT::NHWC, scope); |
|
|
|
unified_inputs[i] = t.to(inp, dst_fmt, scope); |
|
|
|
} else { |
|
|
|
unified_inputs[i] = inputs[i]; |
|
|
|
} |
|
|
|
@@ -375,7 +376,7 @@ ValueRefList elemwise_rule( |
|
|
|
const FormatTransformation& t) { |
|
|
|
FT format = get_inputs_format(inputs, t); |
|
|
|
if (format == FT::NHWC && auto_convert) { |
|
|
|
auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t); |
|
|
|
auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t); |
|
|
|
return t.wrap_outputs( |
|
|
|
imperative::apply(op, t.unwrap_inputs(unified_inputs)), format); |
|
|
|
} |
|
|
|
@@ -389,7 +390,7 @@ ValueRefList concat_rule( |
|
|
|
if (!(format == FT::NHWC && auto_convert)) { |
|
|
|
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format); |
|
|
|
} |
|
|
|
auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t); |
|
|
|
auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t); |
|
|
|
// TODO: handle 5D NHWC Tensor from group conv |
|
|
|
auto axis = op.axis; |
|
|
|
if (axis == 2 || axis == 3) { |
|
|
|
@@ -460,6 +461,12 @@ ValueRefList adaptive_pooling_rule( |
|
|
|
#define FOREACH_FORMAT_POLICY_OP(cb) \ |
|
|
|
cb(Pooling) \ |
|
|
|
cb(Convolution) |
|
|
|
|
|
|
|
// X-macro list of "bypass" ops: invokes `cb` once for each of ParamPackSplit,
// ParamPackConcat, CollectiveComm and CheckNonFinite.
// Elsewhere in this file it is expanded with CREATE_BYPASS_OP_RULE (to define
// one `<Op>_rule` function per op) and with REGISTER_OP_RULE (to register it).
#define FOREACH_BYPASS_OP(cb) \ |
|
|
|
cb(ParamPackSplit) \ |
|
|
|
cb(ParamPackConcat) \ |
|
|
|
cb(CollectiveComm) \ |
|
|
|
cb(CheckNonFinite) |
|
|
|
// clang-format on |
|
|
|
|
|
|
|
// multi inputs op without params |
|
|
|
@@ -517,6 +524,15 @@ FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE) |
|
|
|
} |
|
|
|
FOREACH_FORMAT_POLICY_OP(CREATE_FORMAT_POLICY_OP_RULE) |
|
|
|
|
|
|
|
// Defines `<Op>_rule`, the format rule for an op that bypasses format handling.
// The generated function unwraps every input via t.unwrap_inputs(), applies
// `_op` to the unwrapped values, and re-wraps the results with
// t.wrap_outputs() without passing an explicit format (so the outputs get
// whatever format wrap_outputs uses by default -- presumably FT::DEFAULT;
// TODO confirm against wrap_outputs' declaration).
// `auto_convert` is accepted to match the signature of the other rule
// functions but is not read by the generated body.
// The macro is instantiated for every op in FOREACH_BYPASS_OP and then
// undefined so it does not leak past this point.
#define CREATE_BYPASS_OP_RULE(Op) \ |
|
|
|
ValueRefList Op##_rule( \ |
|
|
|
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \ |
|
|
|
const FormatTransformation& t) { \ |
|
|
|
return t.wrap_outputs(imperative::apply(_op, t.unwrap_inputs(inputs))); \ |
|
|
|
} |
|
|
|
FOREACH_BYPASS_OP(CREATE_BYPASS_OP_RULE) |
|
|
|
#undef CREATE_BYPASS_OP_RULE |
|
|
|
|
|
|
|
#undef CREATE_FORMAT_OP_RULE |
|
|
|
// Expands to a statement that passes the function named `<op>_rule` to
// register_format_rule(); used as the callback for the FOREACH_*_OP lists.
#define REGISTER_OP_RULE(op) register_format_rule(op##_rule); |
|
|
|
struct FormatRuleRegistry { |
|
|
|
@@ -536,6 +552,7 @@ struct FormatRuleRegistry { |
|
|
|
FOREACH_IDENTITY_OP(REGISTER_OP_RULE) |
|
|
|
FOREACH_FORMAT_OP(REGISTER_OP_RULE) |
|
|
|
FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE) |
|
|
|
FOREACH_BYPASS_OP(REGISTER_OP_RULE) |
|
|
|
} |
|
|
|
} _; |
|
|
|
#undef REGISTER_OP_RULE |
|
|
|
@@ -549,10 +566,13 @@ ValueRefList FormatTransformation::apply_transformation( |
|
|
|
if (iter != format_rules.end()) { |
|
|
|
return iter->second(apply_op->op(), inputs, m_auto_convert, *this); |
|
|
|
} else { |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); |
|
|
|
auto unified_inputs = unify_inputs_format( |
|
|
|
inputs, FT::DEFAULT, apply_op->op().scope(), *this); |
|
|
|
return wrap_outputs(imperative::apply(op, unwrap_inputs(unified_inputs))); |
|
|
|
} |
|
|
|
} else if (auto* create_tensor = op.as<CreateTensor>()) { |
|
|
|
auto format = create_tensor->format(); |
|
|
|
// TODO: add dimshuffle for nhwc format |
|
|
|
return {wrap_output(imperative::apply(op, inputs)[0], format)}; |
|
|
|
} else if (auto* get_attr = op.as<GetAttr>()) { |
|
|
|
auto&& input = inputs.item(); |
|
|
|
@@ -570,7 +590,7 @@ ValueRefList FormatTransformation::apply_transformation( |
|
|
|
return {ShapeValue::make(shape)}; |
|
|
|
} |
|
|
|
case GetAttr::Value: { |
|
|
|
auto nchw_src = unwrap_input(to(src, FT::NCHW, "")); |
|
|
|
auto nchw_src = unwrap_input(to(src, FT::DEFAULT, "")); |
|
|
|
return imperative::apply(op, {nchw_src}); |
|
|
|
} |
|
|
|
default: |
|
|
|
|