| @@ -444,14 +444,9 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod | |||||
| if (!anf_node->isa<CNode>()) { | if (!anf_node->isa<CNode>()) { | ||||
| MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; | MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; | ||||
| } | } | ||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (input_idx + 1 >= cnode->inputs().size()) { | |||||
| MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); | |||||
| } | |||||
| auto node = cnode->input(input_idx + 1); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| return VisitKernelWithReturnType(node, 0); | |||||
| auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| return VisitKernelWithReturnType(input_node, 0); | |||||
| } | } | ||||
| std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { | std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { | ||||
| @@ -975,7 +970,7 @@ bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) { | |||||
| AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) { | AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto get_input_index = index + 1; | auto get_input_index = index + 1; | ||||
| if (index + 1 > node->inputs().size()) { | |||||
| if (index + 1 >= node->inputs().size()) { | |||||
| MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just" | MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just" | ||||
| << node->inputs().size(); | << node->inputs().size(); | ||||
| } | } | ||||
| @@ -1061,5 +1061,10 @@ void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) { | |||||
| RunInfer(func_graph, inputs); | |||||
| return CompileGraph(func_graph); | |||||
| } | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -52,6 +52,7 @@ class AscendSession : public SessionBasic { | |||||
| } | } | ||||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | ||||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override; | GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override; | ||||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override; | |||||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | ||||
| void BuildGraph(GraphId) override; | void BuildGraph(GraphId) override; | ||||
| void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "pipeline/jit/parse/data_converter.h" | #include "pipeline/jit/parse/data_converter.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| @@ -1038,6 +1039,45 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { | |||||
| void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } | void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } | ||||
| void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) { | |||||
| auto node_list = TopoSort(func_graph->get_return()); | |||||
| size_t tensor_index = 0; | |||||
| for (const auto &node : node_list) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| AbstractBasePtrList input_abstracts; | |||||
| for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { | |||||
| auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), index); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| auto abstract = input_node->abstract(); | |||||
| MS_EXCEPTION_IF_NULL(abstract); | |||||
| input_abstracts.emplace_back(abstract); | |||||
| } | |||||
| auto prim = AnfAlgo::GetCNodePrimitive(node); | |||||
| if (prim->isa<PrimitiveC>()) { | |||||
| auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>(); | |||||
| MS_EXCEPTION_IF_NULL(prim_c); | |||||
| auto abstract = prim_c->Infer(input_abstracts); | |||||
| node->set_abstract(abstract); | |||||
| } else { | |||||
| node->set_abstract( | |||||
| std::make_shared<tensor::Tensor>(kNumberTypeFloat32, std::vector<int>{32, 64, 218, 218})->ToAbstract()); | |||||
| } | |||||
| } else if (node->isa<Parameter>()) { | |||||
| if (tensor_index > inputs.size()) { | |||||
| MS_EXCEPTION(IndexError) << "Index " << tensor_index << "is out of " << inputs.size() << "tensor's size"; | |||||
| } | |||||
| node->set_abstract(inputs[tensor_index++]->ToAbstract()); | |||||
| } else { | |||||
| auto value_node = node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| auto value = value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| value_node->set_abstract(value->ToAbstract()); | |||||
| } | |||||
| } | |||||
| } | |||||
| void SessionBasic::SetSummaryNodes(KernelGraph *graph) { | void SessionBasic::SetSummaryNodes(KernelGraph *graph) { | ||||
| MS_LOG(DEBUG) << "Update summary Start"; | MS_LOG(DEBUG) << "Update summary Start"; | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| @@ -70,6 +70,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; | virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; | ||||
| virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | ||||
| virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) { | |||||
| MS_EXCEPTION(NotExistsError) << "Call an empty function"; | |||||
| } | |||||
| // build graph, used to handle multiple child graphs | // build graph, used to handle multiple child graphs | ||||
| virtual void BuildGraph(GraphId) {} | virtual void BuildGraph(GraphId) {} | ||||
| @@ -129,6 +132,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs); | void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs); | ||||
| protected: | protected: | ||||
| void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs); | |||||
| // Get graph by graph id ,if not exist return null ptr | // Get graph by graph id ,if not exist return null ptr | ||||
| KernelGraphPtr GetGraph(GraphId graph_id) const; | KernelGraphPtr GetGraph(GraphId graph_id) const; | ||||
| @@ -37,10 +37,10 @@ constexpr auto kPadList = "pad_list"; | |||||
| constexpr auto kConv2DName = "Conv2D"; | constexpr auto kConv2DName = "Conv2D"; | ||||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto conv_prim = std::dynamic_pointer_cast<Conv2d>(primitive); | |||||
| auto conv_prim = primitive->cast<PrimConv2dPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(conv_prim); | MS_EXCEPTION_IF_NULL(conv_prim); | ||||
| auto prim_name = conv_prim->name(); | auto prim_name = conv_prim->name(); | ||||
| CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); | |||||
| CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | ||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); | auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); | ||||
| @@ -99,7 +99,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| } | } | ||||
| TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); | |||||
| CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name()); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| @@ -29,6 +29,7 @@ class Conv2d : public PrimitiveC { | |||||
| public: | public: | ||||
| Conv2d(); | Conv2d(); | ||||
| ~Conv2d() = default; | ~Conv2d() = default; | ||||
| MS_DECLARE_PARENT(Conv2d, PrimitiveC); | |||||
| void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid", | void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid", | ||||
| const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1}, | const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1}, | ||||
| const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1); | const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1); | ||||
| @@ -25,6 +25,7 @@ namespace mindspore { | |||||
| class PrimitiveC : public Primitive { | class PrimitiveC : public Primitive { | ||||
| public: | public: | ||||
| explicit PrimitiveC(const std::string &name) : Primitive(name) {} | explicit PrimitiveC(const std::string &name) : Primitive(name) {} | ||||
| MS_DECLARE_PARENT(PrimitiveC, Primitive); | |||||
| ~PrimitiveC() = default; | ~PrimitiveC() = default; | ||||
| AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list); | AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list); | ||||
| @@ -640,7 +640,7 @@ CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<An | |||||
| return NewCNode(input_node_list); | return NewCNode(input_node_list); | ||||
| } | } | ||||
| ParameterPtr FuncGraph::add_parameter(const tensor::MetaTensorPtr &meta_tensor) { | |||||
| ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) { | |||||
| auto parameter = add_parameter(); | auto parameter = add_parameter(); | ||||
| parameter->set_default_param(MakeValue(meta_tensor)); | parameter->set_default_param(MakeValue(meta_tensor)); | ||||
| parameter->set_abstract(meta_tensor->ToAbstract()); | parameter->set_abstract(meta_tensor->ToAbstract()); | ||||
| @@ -173,7 +173,7 @@ class FuncGraph : public FuncGraphBase { | |||||
| CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope); | CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope); | ||||
| virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs); | virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs); | ||||
| virtual ParameterPtr add_parameter(const tensor::MetaTensorPtr &meta_tensor); | |||||
| virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor); | |||||
| // Functions for handling variable argument, keyword-only arguments and variable keyword argument | // Functions for handling variable argument, keyword-only arguments and variable keyword argument | ||||
| AnfNodePtr GetDefaultValueByName(const std::string &name); | AnfNodePtr GetDefaultValueByName(const std::string &name); | ||||
| void set_param_default_value(const std::string &name, const AnfNodePtr &node) { | void set_param_default_value(const std::string &name, const AnfNodePtr &node) { | ||||
| @@ -64,23 +64,36 @@ std::vector<int> CheckAndConvertUtils::CheckPositiveVector(const std::string &ar | |||||
| const std::vector<int> &arg_value, | const std::vector<int> &arg_value, | ||||
| const std::string &prim_name, bool allow_four, | const std::string &prim_name, bool allow_four, | ||||
| bool ret_four) { | bool ret_four) { | ||||
| auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void { | |||||
| std::ostringstream buffer; | |||||
| buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two "; | |||||
| if (allow_four) { | |||||
| buffer << "or four "; | |||||
| } | |||||
| buffer << " positive int numbers , but got ["; | |||||
| for (auto item : arg_value) { | |||||
| buffer << item << ","; | |||||
| } | |||||
| buffer << "]"; | |||||
| MS_EXCEPTION(ValueError) << buffer.str(); | |||||
| }; | |||||
| for (auto item : arg_value) { | |||||
| if (item < 0) { | |||||
| raise_message(); | |||||
| } | |||||
| } | |||||
| if (arg_value.size() == 1) { | |||||
| return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[0]} : std::vector<int>{arg_value[0], arg_value[0]}; | |||||
| } | |||||
| if (arg_value.size() == 2) { | if (arg_value.size() == 2) { | ||||
| return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[1]} : arg_value; | return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[1]} : arg_value; | ||||
| } else if (arg_value.size() == 4 && allow_four) { | } else if (arg_value.size() == 4 && allow_four) { | ||||
| return ret_four ? arg_value : std::vector<int>{arg_value[2], arg_value[3]}; | return ret_four ? arg_value : std::vector<int>{arg_value[2], arg_value[3]}; | ||||
| } | } | ||||
| std::ostringstream buffer; | |||||
| buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two "; | |||||
| if (allow_four) { | |||||
| buffer << "or four "; | |||||
| } | |||||
| buffer << " positive int numbers , but got ["; | |||||
| for (auto item : arg_value) { | |||||
| buffer << item << ","; | |||||
| } | |||||
| buffer << "]"; | |||||
| MS_EXCEPTION(ValueError) << buffer.str(); | |||||
| raise_message(); | |||||
| return arg_value; | |||||
| } | } | ||||
| std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value, | std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value, | ||||
| const std::set<std::string> &check_list, const std::string &prim_name) { | const std::set<std::string> &check_list, const std::string &prim_name) { | ||||
| if (check_list.find(arg_value) != check_list.end()) { | if (check_list.find(arg_value) != check_list.end()) { | ||||
| @@ -131,6 +144,10 @@ void CheckAndConvertUtils::CheckInRange(const std::string &arg_name, int arg_val | |||||
| if (iter == kCompareRangeMap.end()) { | if (iter == kCompareRangeMap.end()) { | ||||
| MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; | MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; | ||||
| } | } | ||||
| if (range.first >= range.second) { | |||||
| MS_EXCEPTION(ArgumentError) << "the check range left must be larger than right number bug got [ " << range.first | |||||
| << "," << range.second; | |||||
| } | |||||
| if (iter->second(arg_value, range)) { | if (iter->second(arg_value, range)) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H | |||||
| #define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H | |||||
| #ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ | |||||
| #define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <map> | #include <map> | ||||
| @@ -67,4 +67,4 @@ class CheckAndConvertUtils { | |||||
| static bool IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2); | static bool IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2); | ||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H | |||||
| #endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/tensor_construct_utils.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace { | |||||
| template <typename T> | |||||
| void SetTensorData(void *data, float num, size_t data_length) { | |||||
| MS_EXCEPTION_IF_NULL(data); | |||||
| auto tensor_data = reinterpret_cast<T *>(data); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||||
| for (size_t index = 0; index < data_length; ++index) { | |||||
| *tensor_data = num; | |||||
| ++tensor_data; | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int> &shape) { | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | |||||
| size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); | |||||
| auto tensor_data = tensor->data_c(); | |||||
| char *data = reinterpret_cast<char *>(tensor_data); | |||||
| MS_EXCEPTION_IF_NULL(data); | |||||
| (void)memset_s(data, mem_size, 0, mem_size); | |||||
| return tensor; | |||||
| } | |||||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int> &shape) { | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | |||||
| auto mem_size = IntToSize(tensor->ElementsNum()); | |||||
| if (tensor->data_type() == kNumberTypeFloat32) { | |||||
| SetTensorData<float>(tensor->data_c(), 1.0, mem_size); | |||||
| } else if (tensor->data_type() == kNumberTypeInt) { | |||||
| SetTensorData<int>(tensor->data_c(), 1, mem_size); | |||||
| } | |||||
| return tensor; | |||||
| } | |||||
| tensor::TensorPtr TensorConstructUtils::CreateTensor(TypeId type, const std::vector<int> &shape, void *data) { | |||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape, data, type); | |||||
| return tensor; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | |||||
| #define MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | |||||
| #include <vector> | |||||
| #include "ir/tensor.h" | |||||
| namespace mindspore { | |||||
| class TensorConstructUtils { | |||||
| public: | |||||
| static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int> &shape); | |||||
| static tensor::TensorPtr CreateOnesTensor(TypeId type, const std::vector<int> &shape); | |||||
| static tensor::TensorPtr CreateTensor(TypeId type, const std::vector<int> &shape, void *data); | |||||
| }; | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | |||||