|
|
@@ -390,8 +390,8 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp |
|
|
this->ConvInferShape(input_h, input_w, &output_h, &output_w); |
|
|
this->ConvInferShape(input_h, input_w, &output_h, &output_w); |
|
|
|
|
|
|
|
|
std::vector<int> out_shape{input_tensor->shape()}; |
|
|
std::vector<int> out_shape{input_tensor->shape()}; |
|
|
out_shape.at(1) = output_h; |
|
|
|
|
|
out_shape.at(2) = output_w; |
|
|
|
|
|
|
|
|
out_shape.at(1) = output_h > 0 ? output_h : 1; |
|
|
|
|
|
out_shape.at(2) = output_w > 0 ? output_w : 1; |
|
|
out_shape.at(3) = weight_tensor->shape()[0]; |
|
|
out_shape.at(3) = weight_tensor->shape()[0]; |
|
|
out_tensor->set_shape(out_shape); |
|
|
out_tensor->set_shape(out_shape); |
|
|
|
|
|
|
|
|
|