|
|
|
@@ -62,21 +62,32 @@ Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreat |
|
|
|
|
|
|
|
namespace { |
|
|
|
constexpr int kBroadcastToInputNum = 1; |
|
|
|
constexpr int kBroadcastToOnnxInputNum = 2; |
|
|
|
constexpr int kBroadcastToOutputNum = 1; |
|
|
|
} // namespace |
|
|
|
|
|
|
|
int BroadcastTo::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { |
|
|
|
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) { |
|
|
|
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size(); |
|
|
|
if (inputs.size() != kBroadcastToInputNum && inputs.size() != kBroadcastToOnnxInputNum) { |
|
|
|
MS_LOG(ERROR) << "input size:" << inputs.size(); |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
if (outputs.size() != kBroadcastToOutputNum) { |
|
|
|
MS_LOG(ERROR) << "output size:" << outputs.size(); |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
auto input = inputs.at(0); |
|
|
|
outputs[0]->SetFormat(input->GetFormat()); |
|
|
|
outputs[0]->set_data_type(input->data_type()); |
|
|
|
if (!GetInferFlag()) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
std::vector<int32_t> dst_shape(GetDstShape().begin(), GetDstShape().end()); |
|
|
|
std::vector<int32_t> dst_shape(GetDstShape()); |
|
|
|
for (size_t i = 0; i < dst_shape.size(); ++i) { |
|
|
|
if (dst_shape[i] == -1) { |
|
|
|
dst_shape[i] = inputs[0]->shape()[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
auto input_shape = input->shape(); |
|
|
|
std::vector<int> shape(dst_shape.size()); |
|
|
|
int input_shape_index = input_shape.size() - 1; |
|
|
|
|