Browse Source

fix(imperative/ops): improve infer_output_attrs for broadcast

GitOrigin-RevId: 6b7ed55769
tags/v1.3.0
Megvii Engine Team 5 years ago
parent
commit
fe1680b378
1 changed files with 23 additions and 23 deletions
  1. +23
    -23
      imperative/src/impl/ops/broadcast.cpp

+ 23
- 23
imperative/src/impl/ops/broadcast.cpp View File

@@ -58,10 +58,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
auto&& src = inputs[0];
auto&& tshp = inputs[1];

TensorLayout out_layout = src.layout;
TensorShape out_shape;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_layout.ndim = 0;
return {{{out_layout, src.comp_node}}, false};
out_shape.ndim = 0;
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
mgb_assert(
tshp.layout.ndim == 1,
@@ -69,17 +69,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
tshp.layout.ndim);

size_t target_ndim = tshp.layout.shape[0];
out_layout.ndim = target_ndim;
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_layout.shape[i] = ptr[i];
out_shape[i] = ptr[i];
}
mgb_assert(valid_broadcast(src.layout, out_layout),
mgb_assert(valid_broadcast(src.layout, out_shape),
"the input shape %s can not be broadcasted to target shape %s",
src.layout.TensorShape::to_string().c_str(),
out_layout.TensorShape::to_string().c_str());
src.layout.to_string().c_str(),
out_shape.to_string().c_str());

return {{{out_layout, src.comp_node}}, true};
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
@@ -108,10 +108,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
auto&& src = inputs[0];
auto&& tshp = inputs[1];

TensorLayout out_layout = src.layout;
TensorShape out_shape;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_layout.ndim = 0;
return {{{out_layout, src.comp_node}}, false};
out_shape.ndim = 0;
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
mgb_assert(
tshp.layout.ndim == 1,
@@ -119,31 +119,31 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
tshp.layout.ndim);

size_t target_ndim = tshp.layout.shape[0];
out_layout.ndim = target_ndim;
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_layout.shape[i] = ptr[i];
out_shape[i] = ptr[i];
}

if (src.layout.ndim == 0) {
return {{{out_layout, src.comp_node}}, false};
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}

if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(out_layout.shape[op.axis] == -1);
out_layout.shape[op.axis] = 1;
mgb_assert(src.layout.total_nr_elems() % out_layout.total_nr_elems() == 0,
mgb_assert(out_shape[op.axis] == -1);
out_shape[op.axis] = 1;
mgb_assert(src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
"can not reshape from %s to %s",
src.layout.to_string().c_str(),
out_layout.to_string().c_str());
out_layout.shape[op.axis] = src.layout.total_nr_elems() / out_layout.total_nr_elems();
out_shape.to_string().c_str());
out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
} else {
mgb_assert(src.layout.total_nr_elems() == out_layout.total_nr_elems(),
mgb_assert(src.layout.total_nr_elems() == out_shape.total_nr_elems(),
"can not reshape from %s to %s",
src.layout.to_string().c_str(),
out_layout.to_string().c_str());
out_shape.to_string().c_str());
}
return {{{out_layout, src.comp_node}}, true};
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

OP_TRAIT_REG(Reshape, Reshape)


Loading…
Cancel
Save