Browse Source

!9072 fix topk fussion help tensor size = 4096*2

From: @jjfeing
Reviewed-by: @chujinjin,@kisnwang
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
3fa7233f68
2 changed files with 9 additions and 12 deletions
  1. +8
    -11
      mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc
  2. +1
    -1
      tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc

+ 8
- 11
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc View File

@@ -26,15 +26,13 @@
#include "runtime/device/kernel_info.h" #include "runtime/device/kernel_info.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"


namespace mindspore {
namespace opt {
namespace mindspore::opt {
constexpr size_t kFloat16Len = 2; // size of float16; constexpr size_t kFloat16Len = 2; // size of float16;
constexpr size_t kTopkIndexK = 1; constexpr size_t kTopkIndexK = 1;
namespace { namespace {
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
tensor::TensorPtr CreateTensor() {
// 1 create tensor // 1 create tensor
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
auto last_dim = shape[shape.size() - 1];
const size_t last_dim = 4096;
std::vector<int64_t> indices_shape = {SizeToLong(last_dim * 2)}; std::vector<int64_t> indices_shape = {SizeToLong(last_dim * 2)};
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16); TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
MS_EXCEPTION_IF_NULL(tensor_type); MS_EXCEPTION_IF_NULL(tensor_type);
@@ -63,8 +61,8 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
return indices_tensor; return indices_tensor;
} }


ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
tensor::TensorPtr indices_tensor = CreateTensor(node);
ValueNodePtr CreateValueNode() {
tensor::TensorPtr indices_tensor = CreateTensor();
MS_EXCEPTION_IF_NULL(indices_tensor); MS_EXCEPTION_IF_NULL(indices_tensor);
auto indices_const = std::make_shared<ValueNode>(indices_tensor); auto indices_const = std::make_shared<ValueNode>(indices_tensor);
MS_EXCEPTION_IF_NULL(indices_const); MS_EXCEPTION_IF_NULL(indices_const);
@@ -159,14 +157,14 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
auto tensor = value->cast<tensor::TensorPtr>(); auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c());
auto *data = reinterpret_cast<int32_t *>(tensor->data_c());
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data)); auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data));
new_cnode->set_input(kTopkIndexK + 1, new_value_node); new_cnode->set_input(kTopkIndexK + 1, new_value_node);


std::unordered_set<size_t> attr_index{kTopkIndexK}; std::unordered_set<size_t> attr_index{kTopkIndexK};
ConstInputToAttr(new_cnode, attr_index); ConstInputToAttr(new_cnode, attr_index);
auto indices_const = CreateValueNode(new_cnode);
auto indices_const = CreateValueNode();
new_cnode->add_input(indices_const); new_cnode->add_input(indices_const);
MS_EXCEPTION_IF_NULL(supported_checker_); MS_EXCEPTION_IF_NULL(supported_checker_);
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
@@ -181,5 +179,4 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod


return new_cnode; return new_cnode;
} }
} // namespace opt
} // namespace mindspore
} // namespace mindspore::opt

+ 1
- 1
tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc View File

@@ -89,7 +89,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
EXPECT_TRUE(value_node->value()->isa<tensor::Tensor>()); EXPECT_TRUE(value_node->value()->isa<tensor::Tensor>());
auto tensor = value_node->value()->cast<tensor::TensorPtr>(); auto tensor = value_node->value()->cast<tensor::TensorPtr>();
EXPECT_EQ(tensor->shape().size(), 1); EXPECT_EQ(tensor->shape().size(), 1);
EXPECT_EQ(tensor->shape()[0], 8);
EXPECT_EQ(tensor->shape()[0], 4096*2);
} }


TEST_F(TestHWTopKSplit, test_topk_no_split) { TEST_F(TestHWTopKSplit, test_topk_no_split) {


Loading…
Cancel
Save