@@ -58,10 +58,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     auto&& src = inputs[0];
     auto&& tshp = inputs[1];
 
-    TensorLayout out_layout = src.layout;
+    TensorShape out_shape;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
-        out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, false};
+        out_shape.ndim = 0;
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
     mgb_assert(
             tshp.layout.ndim == 1,
@@ -69,17 +69,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             tshp.layout.ndim);
 
     size_t target_ndim = tshp.layout.shape[0];
-    out_layout.ndim = target_ndim;
+    out_shape.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
     for (size_t i = 0; i < target_ndim; ++i) {
-        out_layout.shape[i] = ptr[i];
+        out_shape[i] = ptr[i];
     }
-    mgb_assert(valid_broadcast(src.layout, out_layout),
+    mgb_assert(valid_broadcast(src.layout, out_shape),
             "the input shape %s can not be broadcasted to target shape %s",
-            src.layout.TensorShape::to_string().c_str(),
-            out_layout.TensorShape::to_string().c_str());
+            src.layout.to_string().c_str(),
+            out_shape.to_string().c_str());
 
-    return {{{out_layout, src.comp_node}}, true};
+    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
 
 OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
@@ -108,10 +108,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     auto&& src = inputs[0];
     auto&& tshp = inputs[1];
 
-    TensorLayout out_layout = src.layout;
+    TensorShape out_shape;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
-        out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, false};
+        out_shape.ndim = 0;
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
     mgb_assert(
             tshp.layout.ndim == 1,
@@ -119,31 +119,31 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             tshp.layout.ndim);
 
     size_t target_ndim = tshp.layout.shape[0];
-    out_layout.ndim = target_ndim;
+    out_shape.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
     for (size_t i = 0; i < target_ndim; ++i) {
-        out_layout.shape[i] = ptr[i];
+        out_shape[i] = ptr[i];
     }
 
     if (src.layout.ndim == 0) {
-        return {{{out_layout, src.comp_node}}, false};
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
 
     if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
-        mgb_assert(out_layout.shape[op.axis] == -1);
-        out_layout.shape[op.axis] = 1;
-        mgb_assert(src.layout.total_nr_elems() % out_layout.total_nr_elems() == 0,
+        mgb_assert(out_shape[op.axis] == -1);
+        out_shape[op.axis] = 1;
+        mgb_assert(src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
                 "can not reshape from %s to %s",
                 src.layout.to_string().c_str(),
-                out_layout.to_string().c_str());
-        out_layout.shape[op.axis] = src.layout.total_nr_elems() / out_layout.total_nr_elems();
+                out_shape.to_string().c_str());
+        out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
     } else {
-        mgb_assert(src.layout.total_nr_elems() == out_layout.total_nr_elems(),
+        mgb_assert(src.layout.total_nr_elems() == out_shape.total_nr_elems(),
                 "can not reshape from %s to %s",
                 src.layout.to_string().c_str(),
-                out_layout.to_string().c_str());
+                out_shape.to_string().c_str());
     }
-    return {{{out_layout, src.comp_node}}, true};
+    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
 
 OP_TRAIT_REG(Reshape, Reshape)