Browse Source

!25683 Add GraphKernelCallback functions, and call them in AkgKernelJsonGenerator.

Merge pull request !25683 from DeshiChen/1025_genjson
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
c48778df30
7 changed files with 323 additions and 79 deletions
  1. +32
    -64
      mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
  2. +8
    -12
      mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
  3. +79
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.cc
  4. +39
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.h
  5. +22
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_callback.cc
  6. +139
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_callback.h
  7. +4
    -3
      mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.cc

+ 32
- 64
mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc View File

@@ -30,21 +30,17 @@ namespace mindspore::graphkernel {
using kernel::GetInputIndex;
using kernel::GetKernelInput;
using kernel::GetOutputIndex;
using kernel::GetStrProcessorFromContext;
using kernel::OpAttr;
using kernel::OpImplyType;
using kernel::OpInfo;
using kernel::OpIOInfo;
namespace {
std::vector<int> GetDynInputSize(const AnfNodePtr &anf_node) {
std::vector<int> dyn_input_sizes;
std::vector<int64_t> GetDynInputSizes(const AnfNodePtr &anf_node) {
std::vector<int64_t> dyn_input_sizes;
auto primitive = GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->HasAttr(kAttrDynInputSizes)) {
std::vector<int64_t> dyn_input_sizes_me =
GetValue<const std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
(void)std::transform(dyn_input_sizes_me.begin(), dyn_input_sizes_me.end(), std::back_inserter(dyn_input_sizes),
[](const int64_t &value) { return static_cast<int>(value); });
dyn_input_sizes = GetValue<const std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
}
return dyn_input_sizes;
}
@@ -68,7 +64,7 @@ class OpInfoExtractor {

private:
void ExtractInputs(const OpInfoPtr &op_info) const {
auto dyn_input_sizes = GetDynInputSize(cnode_);
auto dyn_input_sizes = GetDynInputSizes(cnode_);
if (dyn_input_sizes.empty()) {
for (size_t i = 1; i < cnode_->size(); i++) {
auto io_info = std::make_shared<OpIOInfo>();
@@ -142,34 +138,6 @@ class OpInfoExtractor {
};
} // namespace

TypeId AkgKernelJsonGenerator::GetInputDataType(const AnfNodePtr &anf_node, size_t real_index) const {
return dump_option_.is_before_select_kernel ? AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index)
: AnfAlgo::GetInputDeviceDataType(anf_node, real_index);
}

std::vector<size_t> AkgKernelJsonGenerator::GetInputShape(const AnfNodePtr &anf_node, size_t real_index) const {
return dump_option_.is_before_select_kernel ? AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index)
: AnfAlgo::GetInputDeviceShape(anf_node, real_index);
}

std::string AkgKernelJsonGenerator::GetInputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
return dump_option_.is_before_select_kernel ? kOpFormat_DEFAULT : AnfAlgo::GetInputFormat(anf_node, real_index);
}

TypeId AkgKernelJsonGenerator::GetOutputDataType(const AnfNodePtr &anf_node, size_t index) const {
return dump_option_.is_before_select_kernel ? AnfAlgo::GetOutputInferDataType(anf_node, index)
: AnfAlgo::GetOutputDeviceDataType(anf_node, index);
}

std::vector<size_t> AkgKernelJsonGenerator::GetOutputShape(const AnfNodePtr &anf_node, size_t index) const {
return dump_option_.is_before_select_kernel ? AnfAlgo::GetOutputInferShape(anf_node, index)
: AnfAlgo::GetOutputDeviceShape(anf_node, index);
}

std::string AkgKernelJsonGenerator::GetOutputFormat(const AnfNodePtr &anf_node, size_t index) const {
return dump_option_.is_before_select_kernel ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(anf_node, index);
}

bool AkgKernelJsonGenerator::GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx,
nlohmann::json *node_json) const {
MS_EXCEPTION_IF_NULL(anf_node);
@@ -242,7 +210,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
}

// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
auto dyn_input_sizes = GetDynInputSize(anf_node);
auto dyn_input_sizes = GetDynInputSizes(anf_node);
size_t real_input_index = 0;
for (size_t i = 0; i < inputs_ptr.size(); i++) {
auto input_ptr = inputs_ptr[i];
@@ -251,10 +219,10 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
return false;
}

size_t input_tensor_num = dyn_input_sizes.empty() ? 1 : IntToSize(dyn_input_sizes[i]);
size_t input_tensor_num = dyn_input_sizes.empty() ? 1 : LongToSize(dyn_input_sizes[i]);
std::vector<nlohmann::json> input_list;
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
auto type_id = this->GetInputDataType(anf_node, real_input_index);
auto type_id = this->cb_->GetInputType(anf_node, real_input_index);
std::string dtype = TypeIdToString(type_id, true);
if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] input [" << real_input_index
@@ -263,13 +231,13 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
}
nlohmann::json input_desc_json;
input_desc_json[kJsonKeyDataType] = dtype;
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(anf_node, real_input_index);
input_desc_json[kJsonKeyFormat] = this->cb_->GetInputFormat(anf_node, real_input_index);
input_desc_json[kJsonKeyName] = input_ptr->name();
input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
auto input_shape = this->GetInputShape(anf_node, real_input_index);
auto input_shape = this->cb_->GetInputShape(anf_node, real_input_index);
if (!is_basic_op_ && GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
MS_LOG(WARNING) << "Pick single value [" << input_desc_json[kJsonKeyValue] << "] from input["
<< real_input_index << "] of node [" << anf_node->DebugString(2);
MS_LOG(DEBUG) << "Pick single value [" << input_desc_json[kJsonKeyValue] << "] from input[" << real_input_index
<< "] of node [" << anf_node->DebugString(2);
input_shape.clear();
}
if (input_shape.empty()) {
@@ -294,7 +262,7 @@ bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, co
auto outputs = op_info->outputs_ptr();
for (size_t i = 0; i < output_tensor_num; i++) {
nlohmann::json output_json;
auto type_id = this->GetOutputDataType(anf_node, i);
auto type_id = this->cb_->GetOutputType(anf_node, i);
std::string dtype = TypeIdToString(type_id, true);
if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] output [" << i << "] data type is null. ";
@@ -303,10 +271,10 @@ bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, co

std::string output_name = outputs[i]->name();
output_json[kJsonKeyDataType] = dtype;
output_json[kJsonKeyFormat] = this->GetOutputFormat(anf_node, i);
output_json[kJsonKeyFormat] = this->cb_->GetOutputFormat(anf_node, i);
output_json[kJsonKeyName] = output_name;
output_json[kJsonKeyTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
auto output_shape = this->GetOutputShape(anf_node, i);
auto output_shape = this->cb_->GetOutputShape(anf_node, i);
if (output_shape.empty()) {
output_shape.push_back(1);
}
@@ -316,7 +284,7 @@ bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, co
return true;
}

void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int64_t> &dyn_input_sizes,
const OpAttrPtr &op_attr, nlohmann::json *attr_json,
const ValuePtr &attr_value) {
MS_EXCEPTION_IF_NULL(anf_node);
@@ -350,7 +318,7 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
if (op_attr->name() == kJsonKeyDataformat) {
size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
for (size_t format_i = 0; format_i < tensor_args_num; format_i++) {
auto input_format = this->GetInputFormat(anf_node, format_i);
auto input_format = this->cb_->GetInputFormat(anf_node, format_i);
data_format.push_back(input_format);
}
} else {
@@ -369,7 +337,7 @@ bool AkgKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, cons
MS_LOG(DEBUG) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty";
return true;
}
auto dyn_input_sizes = GetDynInputSize(anf_node);
auto dyn_input_sizes = GetDynInputSizes(anf_node);
auto primitive = GetCNodePrimitive(anf_node);

// create input name list for "x_shape" in attr with "x" in primitive.
@@ -393,15 +361,15 @@ bool AkgKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, cons
<< " is out of range:" << dyn_input_sizes.size() - 1 << ".";
return false;
}
size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->second], 0));
for (int input_i = 0; input_i < dyn_input_sizes[find_item->second]; input_i++) {
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
size_t tensor_idx = LongToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->second], 0));
for (int64_t input_i = 0; input_i < dyn_input_sizes[find_item->second]; input_i++) {
attr_json[kJsonKeyValue] = this->cb_->GetInputInferShape(anf_node, tensor_idx);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
tensor_idx++;
}
} else {
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->second);
attr_json[kJsonKeyValue] = this->cb_->GetInputInferShape(anf_node, find_item->second);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
}
@@ -568,7 +536,7 @@ bool AkgKernelJsonGenerator::GenerateSingleKernelJson(const AnfNodePtr &anf_node
}

size_t AkgKernelJsonGenerator::GetTensorSize(const nlohmann::json &node_json) const {
const std::vector<size_t> &shape = node_json[kJsonKeyShape];
const ShapeVector &shape = node_json[kJsonKeyShape];
const std::string &dtype = node_json[kJsonKeyDataType];
auto type_ptr = StringToType(dtype);
MS_EXCEPTION_IF_NULL(type_ptr);
@@ -617,7 +585,7 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
(*kernel_json)[kJsonKeyId] = 0; // unused key
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = GetStrProcessorFromContext(); // GetProcessorStr(anf_node);
(*kernel_json)[kJsonKeyProcess] = this->cb_->GetProcessorFromContext();
(*kernel_json)[kJsonKeyComposite] = false;
if (dump_option_.get_compute_capability) {
(*kernel_json)[kJsonKeyComputeCapability] = ComputeCapability::Get();
@@ -711,7 +679,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
(*kernel_json)[kJsonKeyId] = 0; // unused key
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = GetStrProcessorFromContext();
(*kernel_json)[kJsonKeyProcess] = this->cb_->GetProcessorFromContext();
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
if (dump_option_.get_compute_capability) {
@@ -748,12 +716,12 @@ bool AkgKernelJsonGenerator::GenSingleJsons(const std::vector<AnfNodePtr> &anf_n
void AkgKernelJsonGenerator::UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
for (auto const &anf_node : anf_nodes) {
auto dyn_input_sizes = GetDynInputSize(anf_node);
auto dyn_input_sizes = GetDynInputSizes(anf_node);
bool is_dynamic_input = !dyn_input_sizes.empty();
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
size_t real_input_index = 0;
for (size_t i = 0; i < input_num; ++i) {
size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
size_t input_tensor_num = is_dynamic_input ? LongToSize(dyn_input_sizes[i]) : 1;
for (size_t j = 0; j < input_tensor_num; ++j) {
auto tmp_input = GetKernelInput(anf_node, real_input_index);
std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kJsonKeyInputDesc, std::make_pair(i, j));
@@ -781,14 +749,14 @@ nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNod
auto input_index = GetInputIndex(anf_nodes, input_list);
for (size_t i = 0; i < input_index.size(); ++i) {
auto tmp_input = input_index[i];
auto type_id = this->GetInputDataType(tmp_input.first, tmp_input.second.first);
auto type_id = this->cb_->GetInputType(tmp_input.first, tmp_input.second.first);
std::string dtype = TypeIdToString(type_id, true);
nlohmann::json input_desc_json;
input_desc_json[kJsonKeyTensorName] =
GetTensorName(node_json_map.at(tmp_input.first), kJsonKeyInputDesc, tmp_input.second);
input_desc_json[kJsonKeyDataType] = dtype;
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(tmp_input.first, tmp_input.second.first);
auto input_shape = this->GetInputShape(tmp_input.first, tmp_input.second.first);
input_desc_json[kJsonKeyFormat] = this->cb_->GetInputFormat(tmp_input.first, tmp_input.second.first);
auto input_shape = this->cb_->GetInputShape(tmp_input.first, tmp_input.second.first);
if (input_shape.empty()) {
input_shape.push_back(1);
}
@@ -878,13 +846,13 @@ nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNo
}
}
if (!found) {
auto type_id = this->GetOutputDataType(tmp_output.first, tmp_output.second);
auto type_id = this->cb_->GetOutputType(tmp_output.first, tmp_output.second);
std::string dtype = TypeIdToString(type_id, true);
output_desc_json[kJsonKeyTensorName] =
GetTensorName(node_json_map.at(tmp_output.first), kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
output_desc_json[kJsonKeyDataType] = dtype;
output_desc_json[kJsonKeyFormat] = this->GetOutputFormat(tmp_output.first, tmp_output.second);
auto output_shape = this->GetOutputShape(tmp_output.first, tmp_output.second);
output_desc_json[kJsonKeyFormat] = this->cb_->GetOutputFormat(tmp_output.first, tmp_output.second);
auto output_shape = this->cb_->GetOutputShape(tmp_output.first, tmp_output.second);
if (output_shape.empty()) {
output_shape.push_back(1);
}


+ 8
- 12
mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h View File

@@ -22,8 +22,9 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/oplib/opinfo.h"
#include "nlohmann/json.hpp"
#include "backend/kernel_compiler/oplib/opinfo.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"

namespace mindspore::graphkernel {
using kernel::OpAttrPtr;
@@ -89,8 +90,8 @@ class ComputeCapability {

class AkgKernelJsonGenerator {
public:
AkgKernelJsonGenerator() = default;
explicit AkgKernelJsonGenerator(DumpOption dump_option) : dump_option_(std::move(dump_option)) {}
explicit AkgKernelJsonGenerator(DumpOption dump_option)
: dump_option_(std::move(dump_option)), cb_(Callback::Instance()) {}
~AkgKernelJsonGenerator() = default;

bool CollectJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json);
@@ -99,19 +100,19 @@ class AkgKernelJsonGenerator {
bool CollectJson(const AnfNodePtr &anf_node);
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *node_json);
std::string kernel_name() const { return kernel_name_; }
nlohmann::json kernel_json() const { return kernel_json_; }
std::string kernel_json_str() const { return kernel_json_.dump(); }
const std::vector<size_t> &input_size_list() const { return input_size_list_; }
const std::vector<size_t> &output_size_list() const { return output_size_list_; }
void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }

private:
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *node_json);
bool CreateInputDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *inputs_json);
bool CreateOutputDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *outputs_json);
void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes, const OpAttrPtr &op_attr,
void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int64_t> &dyn_input_sizes, const OpAttrPtr &op_attr,
nlohmann::json *attr_json, const ValuePtr &attr_value);
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json);
void GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map,
@@ -131,12 +132,6 @@ class AkgKernelJsonGenerator {
nlohmann::json *node_json) const;
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position) const;
TypeId GetInputDataType(const AnfNodePtr &anf_node, size_t real_index) const;
std::vector<size_t> GetInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
TypeId GetOutputDataType(const AnfNodePtr &anf_node, size_t index) const;
std::vector<size_t> GetOutputShape(const AnfNodePtr &anf_node, size_t index) const;
std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index) const;
void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json);
OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node) const;
void CollectParallelDimInfo(const AnfNodePtr &anf_node);
@@ -155,6 +150,7 @@ class AkgKernelJsonGenerator {
std::vector<size_t> output_size_list_;
std::map<std::string, AnfNodePtr> address_node_map_;
bool is_basic_op_{false};
Callback *cb_{nullptr};
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_GENERATOR_H_

+ 79
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.cc View File

@@ -0,0 +1,79 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/graph_kernel/adapter/callback_impl.h"

#include <algorithm>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore::graphkernel {
// register the callback object
GRAPH_KERNEL_CALLBACK_REGISTER(CallbackImpl);

namespace {
// Convert an unsigned shape (as returned by AnfAlgo) to the signed ShapeVector
// used throughout graph kernel. Shared by the four shape getters below, which
// previously each repeated this transform inline; the std::transform result is
// explicitly discarded with (void) to follow the file's convention.
ShapeVector ToShapeVector(const std::vector<size_t> &shape) {
  ShapeVector result;
  result.reserve(shape.size());
  (void)std::transform(shape.begin(), shape.end(), std::back_inserter(result), SizeToLong);
  return result;
}
}  // namespace

// Shape of input i of `node`, via AnfAlgo::GetInputDeviceShape.
ShapeVector CallbackImpl::GetInputShape(const AnfNodePtr &node, size_t i) {
  return ToShapeVector(AnfAlgo::GetInputDeviceShape(node, i));
}

// Shape of output i of `node`, via AnfAlgo::GetOutputDeviceShape.
ShapeVector CallbackImpl::GetOutputShape(const AnfNodePtr &node, size_t i) {
  return ToShapeVector(AnfAlgo::GetOutputDeviceShape(node, i));
}

// Inferred shape of input i, taken from the previous node's output (AnfAlgo::GetPrevNodeOutputInferShape).
ShapeVector CallbackImpl::GetInputInferShape(const AnfNodePtr &node, size_t i) {
  return ToShapeVector(AnfAlgo::GetPrevNodeOutputInferShape(node, i));
}

// Inferred shape of output i of `node`, via AnfAlgo::GetOutputInferShape.
ShapeVector CallbackImpl::GetOutputInferShape(const AnfNodePtr &node, size_t i) {
  return ToShapeVector(AnfAlgo::GetOutputInferShape(node, i));
}

// Data type of input i of `node`, via AnfAlgo::GetInputDeviceDataType.
TypeId CallbackImpl::GetInputType(const AnfNodePtr &node, size_t i) { return AnfAlgo::GetInputDeviceDataType(node, i); }

// Data type of output i of `node`, via AnfAlgo::GetOutputDeviceDataType.
TypeId CallbackImpl::GetOutputType(const AnfNodePtr &node, size_t i) {
return AnfAlgo::GetOutputDeviceDataType(node, i);
}

// Inferred data type of input i, taken from the previous node's output.
TypeId CallbackImpl::GetInputInferType(const AnfNodePtr &node, size_t i) {
return AnfAlgo::GetPrevNodeOutputInferDataType(node, i);
}

// Inferred data type of output i of `node`.
TypeId CallbackImpl::GetOutputInferType(const AnfNodePtr &node, size_t i) {
return AnfAlgo::GetOutputInferDataType(node, i);
}

// Format string of input i of `node`, via AnfAlgo::GetInputFormat.
std::string CallbackImpl::GetInputFormat(const AnfNodePtr &node, size_t i) { return AnfAlgo::GetInputFormat(node, i); }

// Format string of output i of `node`, via AnfAlgo::GetOutputFormat.
std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) {
return AnfAlgo::GetOutputFormat(node, i);
}

// Processor string for `node`, via kernel::GetProcessorStr.
std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return kernel::GetProcessorStr(node); }

// Backend processor string taken from the global context (kernel::GetStrProcessorFromContext).
std::string CallbackImpl::GetProcessorFromContext() { return kernel::GetStrProcessorFromContext(); }
} // namespace mindspore::graphkernel

+ 39
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/adapter/callback_impl.h View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_
#include <string>
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"

namespace mindspore::graphkernel {
// Concrete Callback implementation for the ccsrc backend: each query delegates
// to AnfRuntimeAlgorithm / kernel utilities (see callback_impl.cc). Registered
// as the global singleton via GRAPH_KERNEL_CALLBACK_REGISTER in the .cc file.
class CallbackImpl : public Callback {
public:
// Shape queries: "Shape" variants report device shapes, "InferShape" variants
// report shapes from type/shape inference.
ShapeVector GetInputShape(const AnfNodePtr &node, size_t i) override;
ShapeVector GetOutputShape(const AnfNodePtr &node, size_t i) override;
ShapeVector GetInputInferShape(const AnfNodePtr &node, size_t i) override;
ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t i) override;
// Data-type queries, with the same device vs. inferred split as the shapes.
TypeId GetInputType(const AnfNodePtr &node, size_t i) override;
TypeId GetOutputType(const AnfNodePtr &node, size_t i) override;
TypeId GetInputInferType(const AnfNodePtr &node, size_t i) override;
TypeId GetOutputInferType(const AnfNodePtr &node, size_t i) override;
// Data-format queries.
std::string GetInputFormat(const AnfNodePtr &node, size_t i) override;
std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
// Processor (target backend) queries: per-node and from global context.
std::string GetProcessor(const AnfNodePtr &node) override;
std::string GetProcessorFromContext() override;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_

+ 22
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_callback.cc View File

@@ -0,0 +1,22 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"
#include <memory>

namespace mindspore::graphkernel {
// Out-of-line definition of the singleton storage declared in Callback.
// It stays null until a CallbackImplRegister static object installs an
// implementation via Callback::RegImpl (see GRAPH_KERNEL_CALLBACK_REGISTER).
std::unique_ptr<Callback> Callback::instance_{nullptr};
} // namespace mindspore::graphkernel

+ 139
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_callback.h View File

@@ -0,0 +1,139 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_
#include <string>
#include <memory>
#include <functional>

#include "ir/anf.h"
#include "ir/dtype/type_id.h"
#include "utils/shape_utils.h"

namespace mindspore::graphkernel {
/**
 * @brief Abstract query interface that decouples graph-kernel core code from
 *        the backend (AnfRuntimeAlgorithm etc.). A concrete implementation is
 *        installed as a process-wide singleton via GRAPH_KERNEL_CALLBACK_REGISTER.
 */
class Callback {
 public:
  /// @brief Get the registered callback singleton; null until an implementation registers itself.
  static Callback *Instance() { return instance_.get(); }

  /// @brief Virtual destructor is required: the singleton is owned and destroyed
  ///        through a base-class pointer (std::unique_ptr<Callback>), which would
  ///        be undefined behavior without it.
  virtual ~Callback() = default;

  /**
   * @brief Get the real input shape of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the input index, start from 0
   */
  virtual ShapeVector GetInputShape(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the real output shape of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the output index, start from 0
   */
  virtual ShapeVector GetOutputShape(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the inferred input shape of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the input index, start from 0
   */
  virtual ShapeVector GetInputInferShape(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the inferred output shape of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the output index, start from 0
   */
  virtual ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the real input data type of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the input index, start from 0
   */
  virtual TypeId GetInputType(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the real output data type of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the output index, start from 0
   */
  virtual TypeId GetOutputType(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the inferred input data type of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the input index, start from 0
   */
  virtual TypeId GetInputInferType(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the inferred output data type of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the output index, start from 0
   */
  virtual TypeId GetOutputInferType(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the input data format of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the input index, start from 0
   */
  virtual std::string GetInputFormat(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the output data format of the `node`.
   *
   * @param node the AnfNodePtr
   * @param i the output index, start from 0
   */
  virtual std::string GetOutputFormat(const AnfNodePtr &node, size_t i) = 0;

  /**
   * @brief Get the processor of the `node`.
   *
   * @param node the AnfNodePtr
   */
  virtual std::string GetProcessor(const AnfNodePtr &node) = 0;

  /**
   * @brief Get the backend processor from context.
   */
  virtual std::string GetProcessorFromContext() = 0;

 private:
  friend class CallbackImplRegister;
  // Install `cb` as the singleton, taking ownership of the pointer.
  static void RegImpl(Callback *cb) { instance_.reset(cb); }

  static std::unique_ptr<Callback> instance_;
};

// Helper whose construction installs a Callback implementation as the global
// singleton. Instantiated as a file-scope static by the macro below, so
// registration happens during static initialization of the implementing TU.
class CallbackImplRegister {
public:
// `fn` is a factory returning a heap-allocated Callback; ownership of the
// returned pointer transfers to Callback::instance_ via RegImpl.
explicit CallbackImplRegister(std::function<Callback *()> fn) { Callback::RegImpl(fn()); }
~CallbackImplRegister() = default;
};

// Register `cls` (a Callback subclass) as the global callback implementation.
// Expands to a file-scope static registrar, so use it once in the class's .cc file.
#define GRAPH_KERNEL_CALLBACK_REGISTER(cls) \
static const CallbackImplRegister g_graphkernel_callback([]() { return static_cast<Callback *>(new cls()); })
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_CALLBACK_H_

+ 4
- 3
mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.cc View File

@@ -231,8 +231,8 @@ bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph,

std::vector<AnfNodePtr> new_cast_nodes;
for (const auto &index : op_input_indexes) {
auto new_cast_node =
func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
auto new_cast_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())),
AnfAlgo::GetInputNode(type_insens_node, index)});
NodeIOInfo cast_io_info;
cast_io_info.inputs_format.push_back(AnfAlgo::GetInputFormat(type_insens_node, index));
cast_io_info.outputs_format = cast_io_info.inputs_format;
@@ -307,7 +307,8 @@ bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, co
SetTypeInsensitiveNodeInputsInfo(node, op_input_indexes, cast_nodes, &type_insens_io_info, true);
SetNodeInfo(node, new_type_insens_node, type_insens_io_info);

auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
auto new_cast_node =
func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), new_type_insens_node});
NodeIOInfo cast_io_info;
cast_io_info.inputs_format.push_back(pattern_output_format);
cast_io_info.outputs_format = cast_io_info.inputs_format;


Loading…
Cancel
Save