|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" |
|
|
|
|
|
|
|
#include <algorithm> |
|
|
|
#include <map> |
|
|
|
#include <set> |
|
|
|
#include <tuple> |
|
|
|
@@ -121,6 +122,84 @@ bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const |
|
|
|
MS_LOG(INFO) << "Collect fusion json: " << fused_name; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
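// Returns true when every element of the tensor holds the same value, e.g. a
// float32 tensor filled entirely with 0.5. The comparison is done byte-wise
// against the first element, so it works for any dtype without switching on type_id.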
bool TensorElementAllTheSame(const tensor::TensorPtr &tensor) { |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
if (tensor->DataSize() == 1) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
auto data = static_cast<char *>(tensor->data_c()); |
|
|
|
auto itemsize = static_cast<size_t>(tensor->data().itemsize()); |
|
|
|
auto total_cnt = static_cast<size_t>(tensor->DataSize()); |
|
|
|
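  // Compare each element's bytes with those of element 0; any mismatch means the values differ.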
for (size_t i = 1; i < total_cnt; ++i) { |
|
|
|
for (size_t ei = 0; ei < itemsize; ++ei) { |
|
|
|
if (data[ei] != data[i * itemsize + ei]) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
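// Creates a ValueNode holding a tensor with the same rank as the input but with every
// dimension set to 1, copying a single element from the source tensor. The kernel build
// info (format and device type) of the original value node is carried over to the new node.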
AnfNodePtr ConvertToScalarTensor(const AnfNodePtr &value_node) { |
|
|
|
auto tensor = GetValueNode<tensor::TensorPtr>(value_node); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
auto type_id = tensor->data_type(); |
|
|
|
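  // Keep the original rank but set every dimension to 1, e.g. [2, 3] becomes [1, 1].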
ShapeVector new_shape; |
|
|
|
auto origin_ndim = static_cast<size_t>(tensor->DataDim()); |
|
|
|
for (size_t i = 0; i < origin_ndim; ++i) { |
|
|
|
new_shape.push_back(1); |
|
|
|
} |
|
|
|
tensor::TensorPtr scalar_tensor = std::make_shared<tensor::Tensor>(type_id, new_shape); |
|
|
|
scalar_tensor->set_device_info(tensor->device_info()); |
|
|
|
auto data_ptr = scalar_tensor->data_c(); |
|
|
|
MS_EXCEPTION_IF_NULL(data_ptr); |
|
|
|
auto itemsize = static_cast<size_t>(tensor->data().itemsize()); |
|
|
|
  if (memcpy_s(data_ptr, itemsize, tensor->data_c(), itemsize) != 0) {
|
|
|
    MS_LOG(EXCEPTION) << "Failed to copy data from the tensor into the scalar tensor.";
|
|
|
} |
|
|
|
|
|
|
|
  auto new_value_node = NewValueNode(scalar_tensor);
|
|
|
new_value_node->set_abstract(scalar_tensor->ToAbstract()); |
|
|
|
new_value_node->set_kernel_info(std::make_shared<device::KernelInfo>()); |
|
|
|
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); |
|
|
|
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{GetFormat(value_node)}); |
|
|
|
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{type_id}); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); |
|
|
|
|
|
|
|
return new_value_node; |
|
|
|
} |
|
|
|
|
|
|
|
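// Replaces every all-same-value constant tensor with BroadcastTo(scalar), so that only a
// single element is kept in the graph; the BroadcastTo node restores the original device shape.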
void ReplaceTensorWithScalar(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &scalar_tensors) { |
|
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
|
if (scalar_tensors.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
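  // Replacing nodes requires a FuncGraphManager; create one for the sub graph if it does not exist yet.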
auto sub_mng = fg->manager(); |
|
|
|
if (sub_mng == nullptr) { |
|
|
|
sub_mng = Manage(fg, true); |
|
|
|
fg->set_manager(sub_mng); |
|
|
|
} |
|
|
|
|
|
|
|
std::map<AnfNodePtr, AnfNodePtr> to_be_replaced; |
|
|
|
  for (const auto &scalar_tensor_node : scalar_tensors) {
|
|
|
auto scalar = ConvertToScalarTensor(scalar_tensor_node); |
|
|
|
auto format = GetFormat(scalar_tensor_node); |
|
|
|
auto dst_shape_vec = GetShape(scalar_tensor_node); |
|
|
|
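    // Rebuild the constant as BroadcastTo(scalar): the CNode keeps the original
    // format/shape/type, and the target device shape is stored in the "shape" attribute.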
AnfNodePtrList new_broadcast_inputs = {NewValueNode(prim::kPrimBroadcastTo), scalar}; |
|
|
|
auto broadcast_node = CreateCNode(new_broadcast_inputs, fg, |
|
|
|
{.format = format, .shape = dst_shape_vec, .type = GetType(scalar_tensor_node)}); |
|
|
|
auto device_shape = GetDeviceShape(scalar_tensor_node); |
|
|
|
SetNodeAttrSafely("shape", MakeValue(device_shape), broadcast_node); |
|
|
|
to_be_replaced[scalar_tensor_node] = broadcast_node; |
|
|
|
} |
|
|
|
|
|
|
|
  for (const auto &[old_value_node, new_node] : to_be_replaced) {
|
|
|
sub_mng->Replace(old_value_node, new_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) { |
|
|
|
@@ -128,20 +207,28 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i |
|
|
|
auto nodes = TopoSort(fg->get_return()); |
|
|
|
|
|
|
|
OrderedMap<ValuePtr, AnfNodePtrList> vmap; |
|
|
|
std::vector<AnfNodePtr> scalar_tensors; |
|
|
|
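  // Constant tensors whose elements are all equal are collected in scalar_tensors and later
  // replaced by scalar + BroadcastTo; the remaining non-scalar constants are grouped by value
  // in vmap so they can be lifted to graph parameters later in this function.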
for (const auto &node : nodes) { |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &inputs = node->cast<CNodePtr>()->inputs(); |
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) { |
|
|
|
auto tnode = inputs[i]; |
|
|
|
const auto &tnode = inputs[i]; |
|
|
|
auto tensor = GetValueNode<tensor::TensorPtr>(tnode); |
|
|
|
if (tensor && (tensor->DataSize() > 1)) { |
|
|
|
if (tensor == nullptr || tensor->DataSize() == 1) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (TensorElementAllTheSame(tensor)) { |
|
|
|
scalar_tensors.emplace_back(tnode); |
|
|
|
} else { |
|
|
|
vmap[GetValueNode(tnode)].push_back(tnode); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
ReplaceTensorWithScalar(fg, scalar_tensors); |
|
|
|
|
|
|
|
if (vmap.empty()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
@@ -169,6 +256,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i |
|
|
|
|
|
|
|
inputs.push_back(vnode); |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -660,7 +748,22 @@ ShapeVector GetShape(const AnfNodePtr &node) { |
|
|
|
if (shape == nullptr || !shape->isa<abstract::Shape>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Cannot get shape from " << node->fullname_with_scope(); |
|
|
|
} |
|
|
|
return shape->cast<abstract::ShapePtr>()->shape(); |
|
|
|
auto shape_vec = shape->cast<abstract::ShapePtr>()->shape(); |
|
|
|
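  // A scalar (rank-0) output comes back as an empty shape; normalize it to {1} so callers
  // always get at least one dimension.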
if (shape_vec.empty()) { |
|
|
|
shape_vec.push_back(1); |
|
|
|
} |
|
|
|
return shape_vec; |
|
|
|
} |
|
|
|
|
|
|
|
ShapeVector GetDeviceShape(const AnfNodePtr &node) { |
|
|
|
ShapeVector res_device_shape; |
|
|
|
auto device_shape = AnfAlgo::GetOutputDeviceShape(node, 0); |
|
|
|
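  // The device shape is also empty for scalars; fall back to {1} to stay consistent with GetShape.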
if (device_shape.empty()) { |
|
|
|
res_device_shape.push_back(1); |
|
|
|
} else { |
|
|
|
std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(res_device_shape), SizeToLong); |
|
|
|
} |
|
|
|
return res_device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) { |
|
|
|
|