|
|
@@ -26,15 +26,13 @@ |
|
|
#include "runtime/device/kernel_info.h" |
|
|
#include "runtime/device/kernel_info.h" |
|
|
#include "utils/ms_context.h" |
|
|
#include "utils/ms_context.h" |
|
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
|
|
namespace opt { |
|
|
|
|
|
|
|
|
namespace mindspore::opt { |
|
|
constexpr size_t kFloat16Len = 2;  // size of float16 in bytes |
|
|
constexpr size_t kFloat16Len = 2;  // size of float16 in bytes |
|
|
constexpr size_t kTopkIndexK = 1; |
|
|
constexpr size_t kTopkIndexK = 1; |
|
|
namespace { |
|
|
namespace { |
|
|
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
|
tensor::TensorPtr CreateTensor() { |
|
|
// 1 create tensor |
|
|
// 1 create tensor |
|
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); |
|
|
|
|
|
auto last_dim = shape[shape.size() - 1]; |
|
|
|
|
|
|
|
|
const size_t last_dim = 4096; |
|
|
std::vector<int64_t> indices_shape = {SizeToLong(last_dim * 2)}; |
|
|
std::vector<int64_t> indices_shape = {SizeToLong(last_dim * 2)}; |
|
|
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16); |
|
|
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16); |
|
|
MS_EXCEPTION_IF_NULL(tensor_type); |
|
|
MS_EXCEPTION_IF_NULL(tensor_type); |
|
|
@@ -63,8 +61,8 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { |
|
|
return indices_tensor; |
|
|
return indices_tensor; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
ValueNodePtr CreateValueNode(const AnfNodePtr &node) { |
|
|
|
|
|
tensor::TensorPtr indices_tensor = CreateTensor(node); |
|
|
|
|
|
|
|
|
ValueNodePtr CreateValueNode() { |
|
|
|
|
|
tensor::TensorPtr indices_tensor = CreateTensor(); |
|
|
MS_EXCEPTION_IF_NULL(indices_tensor); |
|
|
MS_EXCEPTION_IF_NULL(indices_tensor); |
|
|
auto indices_const = std::make_shared<ValueNode>(indices_tensor); |
|
|
auto indices_const = std::make_shared<ValueNode>(indices_tensor); |
|
|
MS_EXCEPTION_IF_NULL(indices_const); |
|
|
MS_EXCEPTION_IF_NULL(indices_const); |
|
|
@@ -159,14 +157,14 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod |
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
auto tensor = value->cast<tensor::TensorPtr>(); |
|
|
auto tensor = value->cast<tensor::TensorPtr>(); |
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c()); |
|
|
|
|
|
|
|
|
auto *data = reinterpret_cast<int32_t *>(tensor->data_c()); |
|
|
MS_EXCEPTION_IF_NULL(data); |
|
|
MS_EXCEPTION_IF_NULL(data); |
|
|
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data)); |
|
|
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data)); |
|
|
new_cnode->set_input(kTopkIndexK + 1, new_value_node); |
|
|
new_cnode->set_input(kTopkIndexK + 1, new_value_node); |
|
|
|
|
|
|
|
|
std::unordered_set<size_t> attr_index{kTopkIndexK}; |
|
|
std::unordered_set<size_t> attr_index{kTopkIndexK}; |
|
|
ConstInputToAttr(new_cnode, attr_index); |
|
|
ConstInputToAttr(new_cnode, attr_index); |
|
|
auto indices_const = CreateValueNode(new_cnode); |
|
|
|
|
|
|
|
|
auto indices_const = CreateValueNode(); |
|
|
new_cnode->add_input(indices_const); |
|
|
new_cnode->add_input(indices_const); |
|
|
MS_EXCEPTION_IF_NULL(supported_checker_); |
|
|
MS_EXCEPTION_IF_NULL(supported_checker_); |
|
|
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { |
|
|
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { |
|
|
@@ -181,5 +179,4 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod |
|
|
|
|
|
|
|
|
return new_cnode; |
|
|
return new_cnode; |
|
|
} |
|
|
} |
|
|
} // namespace opt |
|
|
|
|
|
} // namespace mindspore |
|
|
|
|
|
|
|
|
} // namespace mindspore::opt |