diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp
index 681ad132..995149ec 100644
--- a/imperative/src/impl/ops/specializations.cpp
+++ b/imperative/src/impl/ops/specializations.cpp
@@ -61,6 +61,52 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
 }
 
+// Infer the output descriptor of Dimshuffle from its single input descriptor.
+// Returns {desc, false} when the input layout reports ndim == 0, and
+// {desc, true} after the output shape has been computed from ds.pattern.
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& ds = static_cast<const Dimshuffle&>(def);
+    mgb_assert(
+            ds.pattern.size() <= TensorShape::MAX_NDIM,
+            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 inputs; got %zu actually", nr_inp);
+    auto&& src = inputs[0];
+    TensorShape out_shape;
+    if (src.layout.ndim == 0) {
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
+    }
+    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
+    mgb_assert(
+            src.layout.ndim == pattern_ndim,
+            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
+            src.layout.ndim);
+    size_t idx = 0;
+    bool input_used[TensorLayout::MAX_NDIM] = {0};
+    // One output axis per pattern entry. Set ndim explicitly: the out_shape[idx]
+    // writes below are assumed to fill extents only (confirm against TensorShape).
+    out_shape.ndim = ds.pattern.size();
+    for (auto i : ds.pattern) {
+        if (i < 0) {
+            // negative entry: insert a new axis of extent 1
+            out_shape[idx] = 1;
+        } else {
+            input_used[i] = true;
+            out_shape[idx] = src.layout.shape[i];
+        }
+        ++idx;
+    }
+    // An input axis that the pattern leaves out may only be dropped if its extent is 1.
+    for (size_t i = 0; i < pattern_ndim; ++i) {
+        mgb_assert(
+                input_used[i] || src.layout.shape[i] == 1,
+                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
+                src.layout.megdnn::TensorShape::to_string().c_str(), i);
+    }
+    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
+}
+
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs,
         SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
@@ -110,6 +156,7 @@ OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
         .make_from_op_node(make_from_op_node)
         .apply_on_var_node(apply_on_var_node)
         .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .fallback();
 }  // namespace dimshuffle
 }  // namespace
@@ -127,6 +174,25 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     return opr::AxisAddRemove::make(inputs[0], param, config);
 }
 
+// Infer the output descriptor of AddAxis from its single input descriptor.
+// Returns {desc, false} when the input layout reports ndim == 0; otherwise
+// applies add_axis_cont_inplace for each entry of op_def.axis, in order.
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op_def = def.cast_final_safe<AddAxis>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 1, "AddAxis expects 1 inputs; got %zu actually", nr_inp);
+    auto&& src = inputs[0];
+    auto olayout = src.layout;
+    if (src.layout.ndim == 0) {
+        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
+    }
+    for (auto&& i : op_def.axis) {
+        olayout.add_axis_cont_inplace(i);
+    }
+    return {{{olayout, src.comp_node}}, true};
+}
+
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs,
         SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
@@ -145,6 +211,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 OP_TRAIT_REG(AddAxis, AddAxis)
         .apply_on_var_node(apply_on_var_node)
         .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .fallback();
 }  // namespace add_axis
 }  // namespace
@@ -188,9 +255,41 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return {Tensor::make(src->blob(), src->offset(), tlayout)};
 }
 
+// Infer the output descriptor of RemoveAxis from its single input descriptor.
+// Returns {desc, false} when the input layout reports ndim == 0; otherwise
+// processes each entry of op_def.axis in order, asserting the axis has extent 1.
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op_def = def.cast_final_safe<RemoveAxis>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 inputs; got %zu actually", nr_inp);
+    auto&& src = inputs[0];
+    auto olayout = src.layout;
+    if (src.layout.ndim == 0) {
+        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
+    }
+    for (auto&& i : op_def.axis) {
+        if (olayout.ndim == 1) {
+            // last remaining axis: it is validated but kept, not removed
+            mgb_assert(
+                    olayout.shape[0] == 1 && i == 0,
+                    "can not remove axis %u from tensor of shape=%s", i,
+                    olayout.megdnn::TensorShape::to_string().c_str());
+        } else {
+            mgb_assert(
+                    i < olayout.ndim && olayout.shape[i] == 1,
+                    "can not remove axis %u from tensor of shape=%s", i,
+                    olayout.megdnn::TensorShape::to_string().c_str());
+            olayout.remove_axis_inplace(i);
+        }
+    }
+    return {{{olayout, src.comp_node}}, true};
+}
+
 OP_TRAIT_REG(RemoveAxis, RemoveAxis)
         .apply_on_var_node(apply_on_var_node)
         .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .fallback();
 }  // namespace remove_axis
 }  // namespace