Browse Source

fix reshape

tags/v1.5.0-rc1
Yang Jiao 4 years ago
parent
commit
2b4a784b86
2 changed files with 28 additions and 6 deletions
  1. +7
    -3
      mindspore/ccsrc/backend/optimizer/gpu/insert_format_transform_op.cc
  2. +21
    -3
      mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc

+ 7
- 3
mindspore/ccsrc/backend/optimizer/gpu/insert_format_transform_op.cc View File

@@ -100,9 +100,13 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
MS_EXCEPTION_IF_NULL(transpose_op);
// 3.Set the output info of transpose.
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
if (!is_fake) {
auto transpose_shape = AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index);
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, {transpose_shape}, transpose_op.get());
if (is_fake) {
std::vector<int64_t> shape;
std::transform(transpose_shape.begin(), transpose_shape.end(), std::back_inserter(shape), SizeToLong);
AnfAlgo::SetNodeAttr("shape", MakeValue(shape), transpose_op);
} else {
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
}
// 4. Set the new edge of transpose op.


+ 21
- 3
mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc View File

@@ -31,6 +31,24 @@
namespace mindspore {
namespace opt {
namespace graphkernel {
std::vector<int64_t> GetListInt(const ValuePtr &attr_value) {
bool is_int64 = true;
auto get_int_value = [&is_int64](const ValuePtr &value) -> int64_t {
if (value->isa<Int64Imm>()) {
return GetValue<int64_t>(value);
}
is_int64 = false;
return static_cast<int64_t>(GetValue<int>(value));
};
std::vector<int64_t> list_int;
const auto &vals = attr_value->cast<ValueSequeuePtr>()->value();
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value);
if (!is_int64) {
MS_LOG(WARNING) << "Vector type should be 'int64_t' but got 'int'";
}
return list_int;
}
void PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
this->shape = InferShape(inputs, attrs);
this->type = InferType(inputs, attrs);
@@ -161,11 +179,11 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
}
DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
return GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
return GetListInt(attrs.find("shape")->second);
}
DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto new_shape = GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
auto new_shape = GetListInt(attrs.find("shape")->second);
auto origin_shape = inputs[0]->shape;
for (size_t i = 0; i < new_shape.size(); i++) {
if (new_shape[i] == -1) {
@@ -179,7 +197,7 @@ DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
}
DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto axis = GetValue<std::vector<int64_t>>(attrs.find("axis")->second);
auto axis = GetListInt(attrs.find("axis")->second);
auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);
if (keepdims) {
DShape new_shape = inputs[0]->shape;


Loading…
Cancel
Save