|
|
|
@@ -83,17 +83,23 @@ class BroadcastOpGpuKernel : public GpuKernel { |
|
|
|
rhs_shape_.resize(MAX_DIMS, 1); |
|
|
|
output_shape_.resize(MAX_DIMS, 1); |
|
|
|
for (size_t i = 0; i < shape3.size(); i++) { |
|
|
|
output_shape_[i] = shape3[i]; |
|
|
|
if (need_broadcast_) { |
|
|
|
output_shape_[i] = shape3[i]; |
|
|
|
} |
|
|
|
output_num_ *= shape3[i]; |
|
|
|
} |
|
|
|
int lhs_offset = shape3.size() - shape1.size(); |
|
|
|
for (size_t j = 0; j < shape1.size(); j++) { |
|
|
|
lhs_shape_[j + lhs_offset] = shape1[j]; |
|
|
|
if (need_broadcast_) { |
|
|
|
lhs_shape_[j + lhs_offset] = shape1[j]; |
|
|
|
} |
|
|
|
input1_num_ *= shape1[j]; |
|
|
|
} |
|
|
|
int rhs_offset = shape3.size() - shape2.size(); |
|
|
|
for (size_t k = 0; k < shape2.size(); k++) { |
|
|
|
rhs_shape_[k + rhs_offset] = shape2[k]; |
|
|
|
if (need_broadcast_) { |
|
|
|
rhs_shape_[k + rhs_offset] = shape2[k]; |
|
|
|
} |
|
|
|
input2_num_ *= shape2[k]; |
|
|
|
} |
|
|
|
|
|
|
|
|