@@ -61,6 +61,44 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    mgb_assert(
            ds.pattern.size() <= TensorShape::MAX_NDIM,
            "Dimshuffle pattern exceeds max length of %zu", TensorShape::MAX_NDIM);
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 input; got %zu actually", nr_inp);
    auto&& src = inputs[0];
    TensorShape out_shape;
    if (src.layout.ndim == 0) {
        // input shape is not known yet: report an unvalidated (fallible) result
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
    mgb_assert(
            src.layout.ndim == pattern_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zu actual=%zu", pattern_ndim,
            src.layout.ndim);
    out_shape.ndim = ds.pattern.size();
    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {false};
    for (auto i : ds.pattern) {
        if (i < 0) {
            // a negative pattern entry inserts a new axis of extent 1
            out_shape[idx] = 1;
        } else {
            input_used[i] = true;
            out_shape[idx] = src.layout.shape[i];
        }
        ++idx;
    }
    for (size_t i = 0; i < pattern_ndim; ++i) {
        mgb_assert(
                input_used[i] || src.layout.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zu",
                src.layout.megdnn::TensorShape::to_string().c_str(), i);
    }
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
@@ -110,6 +148,7 @@ OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
} // namespace dimshuffle
} // namespace
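
For context, here is a standalone sketch of the shape rule the inference above implements; it is hypothetical and not part of this diff: each pattern entry selects an input axis, a negative entry inserts a new axis of extent 1, and any input axis the pattern drops must have extent 1.

// Illustrative toy reimplementation of the Dimshuffle shape rule (not MegEngine API).
#include <cstdio>
#include <vector>

std::vector<size_t> dimshuffle_shape(
        const std::vector<int>& pattern, const std::vector<size_t>& in) {
    std::vector<size_t> out;
    for (int i : pattern)
        out.push_back(i < 0 ? 1 : in[i]);  // a negative entry inserts a unit axis
    return out;
}

int main() {
    // pattern {0, 2, 1, -1} on shape (2, 3, 4) yields (2, 4, 3, 1)
    for (size_t d : dimshuffle_shape({0, 2, 1, -1}, {2, 3, 4}))
        std::printf("%zu ", d);
    std::printf("\n");
}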

@@ -127,6 +166,22 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<AddAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "AddAxis expects 1 input; got %zu actually", nr_inp);
    auto&& src = inputs[0];
    auto olayout = src.layout;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    for (auto&& i : op_def.axis) {
        olayout.add_axis_cont_inplace(i);
    }
    return {{{olayout, src.comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {

@@ -145,6 +200,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
OP_TRAIT_REG(AddAxis, AddAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
} // namespace add_axis
} // namespace
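
A quick illustration, not from the source, of why the loop above applies add_axis_cont_inplace one entry at a time: each axis index refers to the layout produced by the previous insertion.

// Illustrative sketch of sequential unit-axis insertion (std::vector stands in
// for TensorLayout here).
#include <cstdio>
#include <vector>

int main() {
    std::vector<size_t> shape = {3, 4};
    for (size_t ax : {size_t(0), size_t(2)})
        shape.insert(shape.begin() + ax, 1);  // (3,4) -> (1,3,4) -> (1,3,1,4)
    for (size_t d : shape)
        std::printf("%zu ", d);  // prints: 1 3 1 4
    std::printf("\n");
}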

@@ -188,9 +244,37 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
    // axis manipulation only rewrites metadata: reuse the blob with a new layout
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<RemoveAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 input; got %zu actually", nr_inp);
    auto&& src = inputs[0];
    auto olayout = src.layout;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    for (auto&& i : op_def.axis) {
        if (olayout.ndim == 1) {
            // a 1-d tensor of shape (1) stays 1-d instead of becoming 0-d
            mgb_assert(
                    olayout.shape[0] == 1 && i == 0,
                    "can not remove axis %u from tensor of shape=%s", i,
                    olayout.megdnn::TensorShape::to_string().c_str());
        } else {
            mgb_assert(
                    i < olayout.ndim && olayout.shape[i] == 1,
                    "can not remove axis %u from tensor of shape=%s", i,
                    olayout.megdnn::TensorShape::to_string().c_str());
            olayout.remove_axis_inplace(i);
        }
    }
    return {{{olayout, src.comp_node}}, true};
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
} // namespace remove_axis
} // namespace
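
And the mirror-image sketch for RemoveAxis, again illustrative rather than the library's implementation: only unit axes may be removed, and the ndim == 1 branch above keeps a shape-(1) tensor 1-d rather than producing a 0-d layout.

// Illustrative sketch of sequential unit-axis removal.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    std::vector<size_t> shape = {1, 3, 1};
    for (size_t ax : {size_t(0), size_t(1)}) {
        assert(shape[ax] == 1);  // only extent-1 axes are removable
        if (shape.size() > 1)    // mirror of the ndim == 1 guard above
            shape.erase(shape.begin() + ax);
    }
    for (size_t d : shape)
        std::printf("%zu ", d);  // (1,3,1) -> (3,1) -> (3); prints: 3
    std::printf("\n");
}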