|
|
|
@@ -31,6 +31,24 @@ |
|
|
|
namespace mindspore {
|
|
|
|
namespace opt {
|
|
|
|
namespace graphkernel {
|
|
|
|
std::vector<int64_t> GetListInt(const ValuePtr &attr_value) {
|
|
|
|
bool is_int64 = true;
|
|
|
|
auto get_int_value = [&is_int64](const ValuePtr &value) -> int64_t {
|
|
|
|
if (value->isa<Int64Imm>()) {
|
|
|
|
return GetValue<int64_t>(value);
|
|
|
|
}
|
|
|
|
is_int64 = false;
|
|
|
|
return static_cast<int64_t>(GetValue<int>(value));
|
|
|
|
};
|
|
|
|
std::vector<int64_t> list_int;
|
|
|
|
const auto &vals = attr_value->cast<ValueSequeuePtr>()->value();
|
|
|
|
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value);
|
|
|
|
if (!is_int64) {
|
|
|
|
MS_LOG(WARNING) << "Vector type should be 'int64_t' but got 'int'";
|
|
|
|
}
|
|
|
|
return list_int;
|
|
|
|
}
|
|
|
|
|
|
|
|
void PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
|
|
|
|
this->shape = InferShape(inputs, attrs);
|
|
|
|
this->type = InferType(inputs, attrs);
|
|
|
|
@@ -161,11 +179,11 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { |
|
|
|
}
|
|
|
|
|
|
|
|
DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
|
|
|
|
return GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
|
|
|
|
return GetListInt(attrs.find("shape")->second);
|
|
|
|
}
|
|
|
|
|
|
|
|
DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
|
|
|
|
auto new_shape = GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
|
|
|
|
auto new_shape = GetListInt(attrs.find("shape")->second);
|
|
|
|
auto origin_shape = inputs[0]->shape;
|
|
|
|
for (size_t i = 0; i < new_shape.size(); i++) {
|
|
|
|
if (new_shape[i] == -1) {
|
|
|
|
@@ -179,7 +197,7 @@ DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { |
|
|
|
}
|
|
|
|
|
|
|
|
DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
|
|
|
|
auto axis = GetValue<std::vector<int64_t>>(attrs.find("axis")->second);
|
|
|
|
auto axis = GetListInt(attrs.find("axis")->second);
|
|
|
|
auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);
|
|
|
|
if (keepdims) {
|
|
|
|
DShape new_shape = inputs[0]->shape;
|
|
|
|
|