Browse Source

[GraphKernel] support lite_adapter for graph kernel.

tags/v1.6.0
chenlei_autodiff 4 years ago
parent
commit
4c0d5dcfe7
4 changed files with 170 additions and 1 deletion
  1. +4
    -0
      mindspore/ccsrc/backend/optimizer/CMakeLists.txt
  2. +123
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/lite_adapter/callback_impl.cc
  3. +41
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/lite_adapter/callback_impl.h
  4. +2
    -1
      tests/ut/cpp/CMakeLists.txt

+ 4
- 0
mindspore/ccsrc/backend/optimizer/CMakeLists.txt View File

@@ -43,6 +43,10 @@ if(ENABLE_AKG AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux")
file(GLOB_RECURSE _GK_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"graph_kernel/*.cc"
)
file(GLOB_RECURSE _GK_LITE_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"graph_kernel/lite_adapter/*.cc"
)
list(REMOVE_ITEM _GK_SRC_LIST ${_GK_LITE_LIST})
list(APPEND _PREACTIVATE_SRC_LIST ${_GK_SRC_LIST})
endif()



+ 123
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/lite_adapter/callback_impl.cc View File

@@ -0,0 +1,123 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/graph_kernel/lite_adapter/callback_impl.h"

#include <algorithm>
#include <string>
#include <tuple>
#include "base/core_ops.h"
#include "ir/dtype.h"
#include "utils/anf_utils.h"
#include "utils/utils.h"

namespace mindspore::graphkernel {
// Register CallbackImpl as the graph-kernel callback implementation for this
// (lite/CPU) build, so the graph kernel core queries shapes/types/formats
// through this object.
GRAPH_KERNEL_CALLBACK_REGISTER(CallbackImpl);

// Resolve the real kernel (node, output index) that produces the i-th input
// of `anf_node`. A TupleGetItem node is visited directly; any other node must
// be a CNode, whose (i+1)-th input (input 0 is the primitive) is visited.
KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t i) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (IsPrimitiveCNode(anf_node, prim::kPrimTupleGetItem)) {
    return AnfUtils::VisitKernel(anf_node, 0);
  }
  auto node = anf_node->cast<CNodePtr>();
  // cast<CNodePtr>() returns nullptr for non-CNodes (Parameter/ValueNode);
  // fail with a clear exception instead of dereferencing a null pointer below.
  MS_EXCEPTION_IF_NULL(node);
  auto input_node = node->input(i + 1);
  MS_EXCEPTION_IF_NULL(input_node);
  return AnfUtils::VisitKernel(input_node, 0);
}

// In the lite adapter there is no separate "device" shape: the device-side
// input/output shapes are exactly the inferred shapes.
ShapeVector CallbackImpl::GetInputShape(const AnfNodePtr &node, size_t i) {
  return GetInputInferShape(node, i);
}

ShapeVector CallbackImpl::GetOutputShape(const AnfNodePtr &node, size_t i) {
  return GetOutputInferShape(node, i);
}

// Inferred shape of the i-th input: locate the node that actually produces
// this input and query the shape of the corresponding output.
ShapeVector CallbackImpl::GetInputInferShape(const AnfNodePtr &node, size_t i) {
  auto [prev_node, output_idx] = GetPrevNodeOutput(node, i);
  return GetOutputInferShape(prev_node, output_idx);
}

// Inferred shape of the i-th output of `node`.
// Three cases are supported: a single-output node (abstract::Shape, only
// index 0 is valid), a multi-output node (abstract::TupleShape) and a scalar
// (abstract::NoShape, reported as an empty ShapeVector). Anything else raises.
ShapeVector CallbackImpl::GetOutputInferShape(const AnfNodePtr &node, size_t i) {
  MS_EXCEPTION_IF_NULL(node);
  auto base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>()) {
    if (i == 0) {
      return base_shape->cast<abstract::ShapePtr>()->shape();
    }
    // A plain Shape means the node has exactly one output.
    MS_LOG(EXCEPTION) << "The node " << node->DebugString() << " is a single output node but got index " << i;
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (i >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << i << " is larger than output number " << tuple_shape->size()
                        << " node:" << node->DebugString();
    }
    auto b_shp = (*tuple_shape)[i];
    if (b_shp->isa<abstract::Shape>()) {
      return b_shp->cast<abstract::ShapePtr>()->shape();
    } else if (b_shp->isa<abstract::NoShape>()) {
      // Scalar element inside a tuple output: empty shape.
      return ShapeVector();
    } else {
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << i
                        << " should be a NoShape, ArrayShape or a TupleShape, but it is " << base_shape->ToString()
                        << " node: " << node->DebugString();
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    // Scalar output: empty shape.
    return ShapeVector();
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape, ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString() << " node: " << node->DebugString();
}

// Device data types equal inferred data types in the lite adapter, so the
// device-type queries simply forward to the infer-type queries.
TypeId CallbackImpl::GetInputType(const AnfNodePtr &node, size_t i) {
  return GetInputInferType(node, i);
}

TypeId CallbackImpl::GetOutputType(const AnfNodePtr &node, size_t i) {
  return GetOutputInferType(node, i);
}

// Inferred data type of the i-th input, taken from the producing node's
// corresponding output.
TypeId CallbackImpl::GetInputInferType(const AnfNodePtr &node, size_t i) {
  auto [prev_node, output_idx] = GetPrevNodeOutput(node, i);
  return GetOutputInferType(prev_node, output_idx);
}

// Inferred data type of the i-th output of `node`.
// Multi-output nodes carry a Tuple type (the i-th element is selected first);
// a TensorType is unwrapped to its element type, everything else reports its
// own type id directly.
TypeId CallbackImpl::GetOutputInferType(const AnfNodePtr &node, size_t i) {
  MS_EXCEPTION_IF_NULL(node);
  auto out_type = node->Type();
  MS_EXCEPTION_IF_NULL(out_type);
  if (out_type->isa<Tuple>()) {
    auto tuple_type = out_type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    if (i >= tuple_type->size()) {
      MS_LOG(EXCEPTION) << "Output index " << i << " must be less than output number " << tuple_type->size();
    }
    out_type = (*tuple_type)[i];
    MS_EXCEPTION_IF_NULL(out_type);
  }
  if (!out_type->isa<TensorType>()) {
    return out_type->type_id();
  }
  auto tensor_type = out_type->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(tensor_type);
  auto element_type = tensor_type->element();
  MS_EXCEPTION_IF_NULL(element_type);
  return element_type->type_id();
}

// The lite adapter works only with the default format, so format queries
// return a constant regardless of the node or index.
std::string CallbackImpl::GetInputFormat(const AnfNodePtr &node, size_t i) {
  return kOpFormat_DEFAULT;
}

std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) {
  return kOpFormat_DEFAULT;
}

// The lite adapter targets the CPU backend exclusively, both per-node and
// from the global context.
std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) {
  return "cpu";
}

std::string CallbackImpl::GetProcessorFromContext() {
  return "cpu";
}
} // namespace mindspore::graphkernel

+ 41
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/lite_adapter/callback_impl.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_CALLBACK_IMPL_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_CALLBACK_IMPL_H_
#include <string>
#include <utility>
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"

namespace mindspore::graphkernel {
// Pair of (producing node, output index) identifying one concrete output.
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
// Callback implementation for the graph-kernel lite adapter.
// Answers shape/type/format/processor queries from the graph kernel core.
// In this adapter device shapes/types coincide with inferred shapes/types,
// formats are the default format, and the processor is the CPU.
class CallbackImpl : public Callback {
public:
// Device-side shape of the i-th input/output (same as the inferred shape).
ShapeVector GetInputShape(const AnfNodePtr &node, size_t i) override;
ShapeVector GetOutputShape(const AnfNodePtr &node, size_t i) override;
// Inferred shape of the i-th input/output.
ShapeVector GetInputInferShape(const AnfNodePtr &node, size_t i) override;
ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t i) override;
// Device-side data type of the i-th input/output (same as the inferred type).
TypeId GetInputType(const AnfNodePtr &node, size_t i) override;
TypeId GetOutputType(const AnfNodePtr &node, size_t i) override;
// Inferred data type of the i-th input/output.
TypeId GetInputInferType(const AnfNodePtr &node, size_t i) override;
TypeId GetOutputInferType(const AnfNodePtr &node, size_t i) override;
// Data format of the i-th input/output (always the default format here).
std::string GetInputFormat(const AnfNodePtr &node, size_t i) override;
std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
// Target processor for the node / from the global context ("cpu" here).
std::string GetProcessor(const AnfNodePtr &node) override;
std::string GetProcessorFromContext() override;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_CALLBACK_IMPL_H_

+ 2
- 1
tests/ut/cpp/CMakeLists.txt View File

@@ -156,7 +156,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/backend/kernel_compiler/tbe/*.cc"
"../../../mindspore/ccsrc/backend/optimizer/ascend/*.cc"
"../../../mindspore/ccsrc/backend/optimizer/graph_kernel/*.cc"
"../../../mindspore/ccsrc/backend/optimizer/graph_kernel/model/*.cc"
"../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc"
"../../../mindspore/ccsrc/backend/session/ascend_session.cc"
"../../../mindspore/ccsrc/backend/session/ascend_auto_monad.cc"
@@ -201,6 +200,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST
"../../../mindspore/ccsrc/backend/optimizer/graph_kernel/lite_adapter/callback_impl.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_compile.cc")
if(ENABLE_SECURITY)
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/profiler/device/profiling.cc")


Loading…
Cancel
Save