|
|
|
@@ -176,13 +176,13 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out |
|
|
|
return RET_INFER_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int> out_shape; |
|
|
|
out_shape_.clear(); |
|
|
|
if (inputs_.size() == kDoubleNum) { |
|
|
|
auto shape_tensor = inputs_.at(1); |
|
|
|
if (shape_tensor->IsConst()) { |
|
|
|
if (shape_tensor->data_c() == nullptr || (shape_tensor->shape().size() == 1 && shape_tensor->shape()[0] == 0)) { |
|
|
|
MS_LOG(DEBUG) << "reshape to a scalar."; |
|
|
|
output->set_shape(out_shape); |
|
|
|
output->set_shape(out_shape_); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -194,23 +194,23 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out |
|
|
|
switch (shape_tensor->data_type()) { |
|
|
|
case kNumberTypeInt8: { |
|
|
|
auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData()); |
|
|
|
CalShape<int8_t>(data, inputs_, &out_shape, shape_size); |
|
|
|
CalShape<int8_t>(data, inputs_, &out_shape_, shape_size); |
|
|
|
} break; |
|
|
|
case kNumberTypeInt32: { |
|
|
|
auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData()); |
|
|
|
CalShape<int32_t>(data, inputs_, &out_shape, shape_size); |
|
|
|
CalShape<int32_t>(data, inputs_, &out_shape_, shape_size); |
|
|
|
} break; |
|
|
|
case kNumberTypeInt64: { |
|
|
|
auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData()); |
|
|
|
CalShape<int64_t>(data, inputs_, &out_shape, shape_size); |
|
|
|
CalShape<int64_t>(data, inputs_, &out_shape_, shape_size); |
|
|
|
} break; |
|
|
|
case kNumberTypeFloat: { |
|
|
|
auto data = reinterpret_cast<float *>(shape_tensor->MutableData()); |
|
|
|
CalShape<float>(data, inputs_, &out_shape, shape_size); |
|
|
|
CalShape<float>(data, inputs_, &out_shape_, shape_size); |
|
|
|
} break; |
|
|
|
case kNumberTypeUInt32: { |
|
|
|
auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData()); |
|
|
|
CalShape<uint32_t>(data, inputs_, &out_shape, shape_size); |
|
|
|
CalShape<uint32_t>(data, inputs_, &out_shape_, shape_size); |
|
|
|
} break; |
|
|
|
default: { |
|
|
|
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type(); |
|
|
|
@@ -219,18 +219,18 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out |
|
|
|
} |
|
|
|
} else if (inputs_.size() == kSingleNum) { |
|
|
|
for (size_t i = 0; i < GetShape().size(); ++i) { |
|
|
|
out_shape.push_back(GetShape().at(i)); |
|
|
|
out_shape_.push_back(GetShape().at(i)); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "inputs tensor size invalid."; |
|
|
|
return RET_INFER_ERR; |
|
|
|
} |
|
|
|
auto ret = CalNewShape(inputs_.front(), &out_shape); |
|
|
|
auto ret = CalNewShape(inputs_.front(), &out_shape_); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "CalNewShape error"; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
output->set_shape(out_shape); |
|
|
|
output->set_shape(out_shape_); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
} // namespace lite |
|
|
|
|