Browse Source

support 3D format

tags/v1.1.0
liubuyu 5 years ago
parent
commit
b2ea8aeae0
2 changed files with 18 additions and 7 deletions
  1. +17
    -6
      mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
  2. +1
    -1
      mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc

+ 17
- 6
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc View File

@@ -32,6 +32,7 @@ namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
@@ -64,20 +65,30 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
}
}

AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
AnfNodePtr trans_node = nullptr;
CNodePtr trans_data = nullptr;
std::string InitDefaultFormat(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// Init
std::string default_format = kOpFormat_DEFAULT;

if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
if (attr == kOpFormat_NCDHW) {
default_format = kOpFormat_NCDHW;
}
} else if (node->isa<ValueNode>() || node->isa<Parameter>()) {
auto out_format = AnfAlgo::GetOutputFormat(node, 0);
if (k3DFormatSet.find(out_format) != k3DFormatSet.end()) {
default_format = kOpFormat_NCDHW;
}
}
return default_format;
}

AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
AnfNodePtr trans_node = nullptr;
CNodePtr trans_data = nullptr;
MS_EXCEPTION_IF_NULL(node);
// Init
std::string default_format = InitDefaultFormat(node);
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;


+ 1
- 1
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc View File

@@ -575,7 +575,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
host_shape.emplace_back(1);
}
std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || kOpFormat_NDC1HWC0) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::PaddingShapeTo4d(host_shape);


Loading…
Cancel
Save