| @@ -23,7 +23,7 @@ usage() | |||
| { | |||
| echo "Usage:" | |||
| echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" | |||
| echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|acl] \\" | |||
| echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|ascend310] \\" | |||
| echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\" | |||
| echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\" | |||
| echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\" | |||
| @@ -45,7 +45,7 @@ usage() | |||
| echo " -i Enable increment building, default off" | |||
| echo " -L Enable load ANF-IR as input of 'infer', default off" | |||
| echo " -j[n] Set the threads when building (Default: -j8)" | |||
| echo " -e Use cpu, gpu, ascend or acl" | |||
| echo " -e Use cpu, gpu, ascend or ascend310" | |||
| echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" | |||
| echo " -D Enable dumping of function graph ir, default on" | |||
| echo " -z Compile dataset & mindrecord, default on" | |||
| @@ -224,7 +224,7 @@ checkopts() | |||
| ENABLE_D="on" | |||
| ENABLE_CPU="on" | |||
| ENABLE_SERVING="on" | |||
| elif [[ "X$OPTARG" == "Xacl" ]]; then | |||
| elif [[ "X$OPTARG" == "Xascend310" ]]; then | |||
| ENABLE_SERVING="on" | |||
| ENABLE_ACL="on" | |||
| elif [[ "X$OPTARG" == "Xcpu" ]]; then | |||
| @@ -21,6 +21,7 @@ | |||
| #include <memory> | |||
| #include "include/api/status.h" | |||
| #include "include/api/types.h" | |||
| #include "include/api/graph.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| @@ -34,6 +35,7 @@ class MS_API CellBase { | |||
| virtual ~CellBase() = default; | |||
| virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; } | |||
| virtual std::shared_ptr<CellBase> Clone() const = 0; | |||
| virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { return SUCCESS; } | |||
| std::vector<Output> operator()(const std::vector<Input> &inputs) const; | |||
| }; | |||
| @@ -41,9 +43,7 @@ template <class T> | |||
| class MS_API Cell : public CellBase { | |||
| public: | |||
| virtual ~Cell() = default; | |||
| std::shared_ptr<CellBase> Clone() const override { | |||
| return std::make_shared<T>(static_cast<const T&>(*this)); | |||
| } | |||
| std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); } | |||
| }; | |||
| class MS_API ParameterCell final : public Cell<ParameterCell> { | |||
| @@ -84,9 +84,33 @@ class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T> | |||
| public: | |||
| explicit OpCell(const std::string &name) : OpCellBase(name) {} | |||
| ~OpCell() override = default; | |||
| std::shared_ptr<CellBase> Clone() const override { | |||
| return std::make_shared<T>(static_cast<const T&>(*this)); | |||
| } | |||
| std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); } | |||
| }; | |||
| class MS_API GraphCell final : public Cell<GraphCell> { | |||
| public: | |||
| class GraphImpl; | |||
| GraphCell() = default; | |||
| ~GraphCell() override = default; | |||
| explicit GraphCell(const Graph &); | |||
| explicit GraphCell(Graph &&); | |||
| explicit GraphCell(const std::shared_ptr<Graph> &); | |||
| const std::shared_ptr<Graph> &GetGraph() const { return graph_; } | |||
| Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; | |||
| Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const; | |||
| Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const; | |||
| private: | |||
| friend class ModelImpl; | |||
| Status Load(); | |||
| std::shared_ptr<Graph> graph_; | |||
| std::shared_ptr<GraphImpl> executor_; | |||
| }; | |||
| class MS_API InputAndOutput { | |||
| @@ -96,7 +120,7 @@ class MS_API InputAndOutput { | |||
| // no explicit | |||
| InputAndOutput(const Tensor &); // NOLINT(runtime/explicit) | |||
| InputAndOutput(Tensor &&); // NOLINT(runtime/explicit) | |||
| InputAndOutput(Tensor &&); // NOLINT(runtime/explicit) | |||
| InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index); | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| #define MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| #include <string> | |||
| #include <memory> | |||
| #include "include/api/types.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| class MS_API Context { | |||
| public: | |||
| static Context &Instance(); | |||
| const std::string &GetDeviceTarget() const; | |||
| Context &SetDeviceTarget(const std::string &device_target); | |||
| uint32_t GetDeviceID() const; | |||
| Context &SetDeviceID(uint32_t device_id); | |||
| private: | |||
| Context(); | |||
| ~Context(); | |||
| class ContextImpl; | |||
| std::shared_ptr<ContextImpl> impl_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_INCLUDE_API_GRAPH_H | |||
| #define MINDSPORE_INCLUDE_API_GRAPH_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "include/api/status.h" | |||
| #include "include/api/types.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| class MS_API Graph { | |||
| public: | |||
| class GraphData; | |||
| explicit Graph(const std::shared_ptr<GraphData> &graph_data); | |||
| explicit Graph(std::shared_ptr<GraphData> &&graph_data); | |||
| enum ModelType ModelType() const; | |||
| private: | |||
| friend class GraphCell; | |||
| friend class ModelImpl; | |||
| std::shared_ptr<GraphData> graph_data_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_GRAPH_H | |||
| @@ -22,42 +22,39 @@ | |||
| #include <memory> | |||
| #include "include/api/status.h" | |||
| #include "include/api/types.h" | |||
| #include "include/api/graph.h" | |||
| #include "include/api/cell.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| class ModelImpl; | |||
| // todo: minddata c++ interface | |||
| class DataSet {}; | |||
| class NetWork {}; | |||
| class MS_API Model { | |||
| public: | |||
| Model(const std::string &device_type, uint32_t device_id); | |||
| Model(NetWork network, const std::string &device_type, uint32_t device_id); | |||
| explicit Model(const std::vector<Output> &network); | |||
| explicit Model(const GraphCell &graph); | |||
| ~Model(); | |||
| Model(const Model &) = delete; | |||
| void operator=(const Model &) = delete; | |||
| Status LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options); | |||
| Status LoadModel(const std::string &file_name, ModelType type, const std::map<std::string, std::string> &options); | |||
| Status UnloadModel(); | |||
| Status Build(const std::map<std::string, std::string> &options); | |||
| Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs); | |||
| Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs); | |||
| Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs); | |||
| Status Predict(const std::vector<Buffer> &inputs, std::map<std::string, Buffer> *outputs); | |||
| Status Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs); | |||
| Status Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs); | |||
| Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); | |||
| Status GetInputsInfo(std::vector<Tensor> *tensor_list) const; | |||
| Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const; | |||
| Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const; | |||
| Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const; | |||
| static bool CheckModelSupport(const std::string& device_type, ModelType model_type); | |||
| static bool CheckModelSupport(const std::string &device_type, ModelType model_type); | |||
| private: | |||
| std::shared_ptr<ModelImpl> impl_; | |||
| }; | |||
| extern MS_API const char* kDeviceTypeAscendCL; | |||
| extern MS_API const char* kDeviceTypeAscendMS; | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_MODEL_H | |||
| @@ -23,11 +23,13 @@ | |||
| #include "include/api/status.h" | |||
| #include "include/api/types.h" | |||
| #include "include/api/model.h" | |||
| #include "include/api/graph.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| class MS_API Serialization { | |||
| public: | |||
| static Graph LoadModel(const std::string &file, ModelType model_type); | |||
| static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters); | |||
| static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model); | |||
| static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); | |||
| @@ -102,6 +102,9 @@ class MS_API Buffer { | |||
| std::shared_ptr<Impl> impl_; | |||
| }; | |||
| extern MS_API const char *kDeviceTypeAscend310; | |||
| extern MS_API const char *kDeviceTypeAscend910; | |||
| constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path"; | |||
| constexpr auto kModelOptionDvppCfgPath = "mindspore.option.dvpp_config_file_path"; | |||
| constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file | |||
| @@ -6,22 +6,35 @@ set(LOAD_MINDIR_SRC | |||
| file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc") | |||
| if (ENABLE_ACL) | |||
| file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc" "model/model_converter_utils/*.cc") | |||
| elseif (ENABLE_D) | |||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc") | |||
| add_compile_definitions(ENABLE_ACL) | |||
| include_directories(${CMAKE_SOURCE_DIR}/graphengine/src/ge) | |||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
| file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "model/acl/*.cc" | |||
| "model/model_converter_utils/*.cc" | |||
| "graph/acl/*.cc" | |||
| ) | |||
| endif () | |||
| if (ENABLE_D) | |||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/ms/*.cc") | |||
| endif () | |||
| set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc | |||
| ${API_MS_INFER_SRC} | |||
| ${API_ACL_SRC} | |||
| ${API_OPS_SRC} | |||
| ${LOAD_MINDIR_SRC}) | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/context.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/python_utils.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_data.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc | |||
| ${API_MS_INFER_SRC} | |||
| ${API_ACL_SRC} | |||
| ${API_OPS_SRC} | |||
| ${LOAD_MINDIR_SRC}) | |||
| add_library(mindspore_shared_lib SHARED ${MSLIB_SRC}) | |||
| set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}") | |||
| set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore) | |||
| target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} | |||
| -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf) | |||
| @@ -69,5 +82,6 @@ endif () | |||
| if (ENABLE_D) | |||
| find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) | |||
| target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server}) | |||
| target_link_libraries(mindspore_shared_lib PRIVATE mindspore_core hccl_adapter) | |||
| endif () | |||
| @@ -14,6 +14,9 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/cell.h" | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "cxx_api/graph/graph_impl.h" | |||
| namespace mindspore::api { | |||
| std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); } | |||
| @@ -51,6 +54,52 @@ ParameterCell &ParameterCell::operator=(Tensor &&tensor) { | |||
| return *this; | |||
| } | |||
| GraphCell::GraphCell(const Graph &graph) | |||
| : graph_(std::make_shared<Graph>(graph)), | |||
| executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) { | |||
| MS_EXCEPTION_IF_NULL(graph_); | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| executor_->SetGraph(graph_); | |||
| } | |||
| GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) | |||
| : graph_(graph), | |||
| executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) { | |||
| MS_EXCEPTION_IF_NULL(graph_); | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| executor_->SetGraph(graph_); | |||
| } | |||
| GraphCell::GraphCell(Graph &&graph) | |||
| : graph_(std::make_shared<Graph>(graph)), | |||
| executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) { | |||
| MS_EXCEPTION_IF_NULL(graph_); | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| executor_->SetGraph(graph_); | |||
| } | |||
| Status GraphCell::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| return executor_->Run(inputs, outputs); | |||
| } | |||
| Status GraphCell::Load() { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| return executor_->Load(); | |||
| } | |||
| Status GraphCell::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| return executor_->GetInputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| Status GraphCell::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| return executor_->GetOutputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {} | |||
| InputAndOutput::InputAndOutput(const Tensor &tensor) | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/context.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore::api { | |||
| class Context::ContextImpl { | |||
| public: | |||
| ContextImpl() : device_target_("NotSet"), device_id_(0) {} | |||
| const std::string &GetDeviceTarget() const { return device_target_; } | |||
| void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; } | |||
| uint32_t GetDeviceID() const { return device_id_; } | |||
| void SetDeviceID(uint32_t device_id) { device_id_ = device_id; } | |||
| private: | |||
| std::string device_target_; | |||
| uint32_t device_id_; | |||
| }; | |||
| Context &Context::Instance() { | |||
| static Context context; | |||
| return context; | |||
| } | |||
| const std::string &Context::GetDeviceTarget() const { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->GetDeviceTarget(); | |||
| } | |||
| Context &Context::SetDeviceTarget(const std::string &device_target) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| impl_->SetDeviceTarget(device_target); | |||
| return *this; | |||
| } | |||
| uint32_t Context::GetDeviceID() const { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->GetDeviceID(); | |||
| } | |||
| Context &Context::SetDeviceID(uint32_t device_id) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| impl_->SetDeviceID(device_id); | |||
| return *this; | |||
| } | |||
| Context::Context() : impl_(std::make_shared<Context::ContextImpl>()) { MS_EXCEPTION_IF_NULL(impl_); } | |||
| Context::~Context() {} | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H | |||
| #define MINDSPORE_CCSRC_CXX_API_FACTORY_H | |||
| #include <functional> | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "utils/utils.h" | |||
| namespace mindspore::api { | |||
| template <class T> | |||
| class Factory { | |||
| using U = std::function<std::shared_ptr<T>()>; | |||
| public: | |||
| Factory(const Factory &) = delete; | |||
| void operator=(const Factory &) = delete; | |||
| static Factory &Instance() { | |||
| static Factory instance; | |||
| return instance; | |||
| } | |||
| void Register(const std::string &device_name, U &&creator) { | |||
| if (creators_.find(device_name) == creators_.end()) { | |||
| (void)creators_.emplace(device_name, creator); | |||
| } | |||
| } | |||
| bool CheckModelSupport(const std::string &device_name) { | |||
| return std::any_of(creators_.begin(), creators_.end(), | |||
| [&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; }); | |||
| } | |||
| std::shared_ptr<T> Create(const std::string &device_name) { | |||
| auto iter = creators_.find(device_name); | |||
| if (creators_.end() != iter) { | |||
| MS_EXCEPTION_IF_NULL(iter->second); | |||
| return (iter->second)(); | |||
| } | |||
| MS_LOG(ERROR) << "Unsupported device target " << device_name; | |||
| return nullptr; | |||
| } | |||
| private: | |||
| Factory() = default; | |||
| ~Factory() = default; | |||
| std::map<std::string, U> creators_; | |||
| }; | |||
| template <class T> | |||
| class Registrar { | |||
| using U = std::function<std::shared_ptr<T>()>; | |||
| public: | |||
| Registrar(const std::string &device_name, U creator) { | |||
| Factory<T>::Instance().Register(device_name, std::move(creator)); | |||
| } | |||
| ~Registrar() = default; | |||
| }; | |||
| #define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \ | |||
| static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \ | |||
| #DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); }); | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H | |||
| @@ -0,0 +1,266 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/graph/acl/acl_graph_impl.h" | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/model/acl/model_converter.h" | |||
| #include "cxx_api/python_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore::api { | |||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl); | |||
| std::weak_ptr<AclGraphImpl::AclEnvGuard> AclGraphImpl::global_acl_env_; | |||
| std::mutex AclGraphImpl::global_acl_env_mutex_; | |||
| AclGraphImpl::AclGraphImpl() | |||
| : init_flag_(false), | |||
| load_flag_(false), | |||
| device_type_("AscendCL"), | |||
| device_id_(Context::Instance().GetDeviceID()), | |||
| context_(nullptr), | |||
| acl_env_(nullptr) {} | |||
| AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); } | |||
| Status AclGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Prepare model resource failed."; | |||
| return FAILED; | |||
| } | |||
| return model_process_.PredictFromHost(inputs, outputs); | |||
| } | |||
| Status AclGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Prepare model resource failed."; | |||
| return FAILED; | |||
| } | |||
| return model_process_.GetInputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| Status AclGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Prepare model resource failed."; | |||
| return FAILED; | |||
| } | |||
| return model_process_.GetOutputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| Status AclGraphImpl::LoadAclModel(Buffer om_data) { | |||
| MS_LOG(INFO) << "Start load acl model."; | |||
| // acl load model | |||
| uint32_t acl_model_id; | |||
| auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id); | |||
| if (acl_ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed."; | |||
| return FAILED; | |||
| } | |||
| // acl init model resource | |||
| model_process_.set_model_id(acl_model_id); | |||
| Status ret = model_process_.PreInitModelResource(); | |||
| if (ret != SUCCESS) { | |||
| (void)aclmdlUnload(acl_model_id); | |||
| MS_LOG(ERROR) << "Pre init model resource failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Load acl model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status AclGraphImpl::InitEnv() { | |||
| if (init_flag_) { | |||
| return SUCCESS; | |||
| } | |||
| aclError ret; | |||
| { | |||
| std::lock_guard<std::mutex> lock(global_acl_env_mutex_); | |||
| acl_env_ = global_acl_env_.lock(); | |||
| if (acl_env_ != nullptr) { | |||
| MS_LOG(INFO) << "Acl has been initialized, skip."; | |||
| } else { | |||
| acl_env_ = std::make_shared<AclEnvGuard>(""); | |||
| if (acl_env_->GetErrno() != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Execute aclInit Failed"; | |||
| return FAILED; | |||
| } | |||
| global_acl_env_ = acl_env_; | |||
| MS_LOG(INFO) << "Acl init success"; | |||
| } | |||
| } | |||
| ret = aclrtSetDevice(device_id_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Open device " << device_id_ << " success"; | |||
| ret = aclrtCreateContext(&context_, device_id_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Acl create context failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Create context success"; | |||
| aclrtRunMode run_mode; | |||
| ret = aclrtGetRunMode(&run_mode); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Acl get run mode failed"; | |||
| return FAILED; | |||
| } | |||
| bool is_device = (run_mode == ACL_DEVICE); | |||
| model_process_.SetIsDevice(is_device); | |||
| MS_LOG(INFO) << "Get run mode success is device input/output " << is_device; | |||
| MS_LOG(INFO) << "Init acl success, device id " << device_id_; | |||
| init_flag_ = true; | |||
| return SUCCESS; | |||
| } | |||
| Status AclGraphImpl::FinalizeEnv() { | |||
| if (!init_flag_) { | |||
| return SUCCESS; | |||
| } | |||
| aclError rt_ret = aclrtSetCurrentContext(context_); | |||
| if (rt_ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Set the ascend device context failed"; | |||
| return FAILED; | |||
| } | |||
| Status ret = model_process_.UnLoad(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Unload model inner failed."; | |||
| return FAILED; | |||
| } | |||
| if (context_ != nullptr) { | |||
| rt_ret = aclrtDestroyContext(context_); | |||
| if (rt_ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Destroy context failed"; | |||
| } | |||
| context_ = nullptr; | |||
| } | |||
| MS_LOG(INFO) << "End to destroy context"; | |||
| rt_ret = aclrtResetDevice(device_id_); | |||
| if (rt_ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Reset device " << device_id_ << " failed"; | |||
| } | |||
| MS_LOG(INFO) << "End to reset device " << device_id_; | |||
| init_flag_ = false; | |||
| return SUCCESS; | |||
| } | |||
| Status AclGraphImpl::Load() { | |||
| // check graph type | |||
| if (graph_->ModelType() != ModelType::kOM) { | |||
| Status ret = ConvertToOM(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Load Failed."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| const auto &graph_data = GraphImpl::MutableGraphData(); | |||
| MS_EXCEPTION_IF_NULL(graph_data); | |||
| auto om_data = graph_data->GetOMData(); | |||
| // init | |||
| Status ret = InitEnv(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "InitEnv failed."; | |||
| return FAILED; | |||
| } | |||
| // load model | |||
| if (!load_flag_) { | |||
| ret = LoadAclModel(om_data); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Load acl model failed."; | |||
| return ret; | |||
| } | |||
| load_flag_ = true; | |||
| } | |||
| aclError rt_ret = aclrtSetCurrentContext(context_); | |||
| if (rt_ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Set the ascend device context failed"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status AclGraphImpl::ConvertToOM() { | |||
| MS_LOG(INFO) << "Start convert to om model."; | |||
| RegAllOpFromPython(); | |||
| if (graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid graph_ is null."; | |||
| return FAILED; | |||
| } | |||
| auto &graph_data = GraphImpl::MutableGraphData(); | |||
| MS_EXCEPTION_IF_NULL(graph_data); | |||
| if (graph_->ModelType() == ModelType::kOM) { | |||
| MS_LOG(INFO) << "This model has been built, skip."; | |||
| return SUCCESS; | |||
| } else if (graph_->ModelType() == ModelType::kMindIR) { | |||
| auto func_graph = graph_data->GetFuncGraph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| ModelConverter model_converter; | |||
| Buffer om_data = model_converter.LoadMindIR(func_graph); | |||
| if (om_data.Data() == nullptr || om_data.DataSize() == 0) { | |||
| MS_LOG(ERROR) << "Convert MindIR to OM failed."; | |||
| return FAILED; | |||
| } | |||
| graph_data = std::make_shared<Graph::GraphData>(om_data, ModelType::kOM); | |||
| MS_LOG(INFO) << "Convert MindIR to OM success."; | |||
| return SUCCESS; | |||
| } | |||
| MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType(); | |||
| return FAILED; | |||
| } | |||
| AclGraphImpl::AclEnvGuard::AclEnvGuard(std::string_view cfg_file) { | |||
| errno_ = aclInit(cfg_file.data()); | |||
| if (errno_ != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Execute aclInit Failed"; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Acl init success"; | |||
| } | |||
| AclGraphImpl::AclEnvGuard::~AclEnvGuard() { | |||
| errno_ = aclFinalize(); | |||
| if (errno_ != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Finalize acl failed"; | |||
| } | |||
| MS_LOG(INFO) << "Acl finalize success"; | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H | |||
| #include <functional> | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "include/api/graph.h" | |||
| #include "cxx_api/graph/acl/model_process.h" | |||
| #include "cxx_api/graph/graph_impl.h" | |||
| #include "cxx_api/factory.h" | |||
| namespace mindspore::api { | |||
| class AclGraphImpl : public GraphCell::GraphImpl { | |||
| public: | |||
| AclGraphImpl(); | |||
| ~AclGraphImpl() override; | |||
| Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; | |||
| Status Load() override; | |||
| Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; | |||
| Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; | |||
| private: | |||
| class AclEnvGuard; | |||
| Status ConvertToOM(); | |||
| Status InitEnv(); | |||
| Status FinalizeEnv(); | |||
| Status LoadAclModel(Buffer om_data); | |||
| bool init_flag_; | |||
| bool load_flag_; | |||
| std::string device_type_; | |||
| int32_t device_id_; | |||
| aclrtContext context_; | |||
| std::shared_ptr<AclEnvGuard> acl_env_; | |||
| static std::weak_ptr<AclEnvGuard> global_acl_env_; | |||
| static std::mutex global_acl_env_mutex_; | |||
| ModelProcess model_process_; | |||
| }; | |||
| class AclGraphImpl::AclEnvGuard { | |||
| public: | |||
| explicit AclEnvGuard(std::string_view cfg_file); | |||
| ~AclEnvGuard(); | |||
| aclError GetErrno() const { return errno_; } | |||
| private: | |||
| aclError errno_; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/model/acl/model_process.h" | |||
| #include "cxx_api/graph/acl/model_process.h" | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include "utils/utils.h" | |||
| @@ -35,17 +35,33 @@ static DataType TransToApiType(aclDataType data_type) { | |||
| } | |||
| } | |||
| static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<Tensor> *tensor_list) { | |||
| MS_EXCEPTION_IF_NULL(tensor_list); | |||
| tensor_list->clear(); | |||
| template <class T> | |||
| inline static void ClearIfNotNull(T *vec) { | |||
| if (vec != nullptr) { | |||
| vec->clear(); | |||
| } | |||
| } | |||
| template <class T, class U = std::vector<T>> | |||
| inline static void PushbackIfNotNull(U *vec, T &&item) { | |||
| if (vec != nullptr) { | |||
| vec->emplace_back(item); | |||
| } | |||
| } | |||
| static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names, | |||
| std::vector<std::vector<int64_t>> *shapes, std::vector<DataType> *data_types, | |||
| std::vector<size_t> *mem_sizes) { | |||
| ClearIfNotNull(names); | |||
| ClearIfNotNull(shapes); | |||
| ClearIfNotNull(data_types); | |||
| ClearIfNotNull(mem_sizes); | |||
| for (size_t i = 0; i < acl_tensor_list.size(); ++i) { | |||
| const auto &info = acl_tensor_list[i]; | |||
| Tensor tensor_desc; | |||
| tensor_desc.SetName(info.name); | |||
| tensor_desc.SetDataType(TransToApiType(info.data_type)); | |||
| tensor_desc.SetShape(info.dims); | |||
| tensor_list->push_back(tensor_desc); | |||
| PushbackIfNotNull(names, info.name); | |||
| PushbackIfNotNull(shapes, info.dims); | |||
| PushbackIfNotNull(data_types, TransToApiType(info.data_type)); | |||
| PushbackIfNotNull(mem_sizes, info.buffer_size); | |||
| } | |||
| } | |||
| @@ -272,7 +288,7 @@ Status ModelProcess::UnLoad() { | |||
| return SUCCESS; | |||
| } | |||
| Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inputs) { | |||
| Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) { | |||
| aclError ret; | |||
| inputs_ = aclmdlCreateDataset(); | |||
| // check inputs | |||
| @@ -282,29 +298,16 @@ Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inpu | |||
| return INVALID_INPUTS; | |||
| } | |||
| for (size_t i = 0; i < input_infos_.size(); ++i) { | |||
| const std::string &input_name = input_infos_[i].name; | |||
| auto iter = inputs.find(input_name); | |||
| if (iter == inputs.end()) { | |||
| MS_LOG(ERROR) << "Model missing input " << input_name; | |||
| return INVALID_INPUTS; | |||
| } | |||
| if (iter->second.DataSize() != input_infos_[i].buffer_size) { | |||
| if (inputs[i].DataSize() != input_infos_[i].buffer_size) { | |||
| MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size | |||
| << ", given count " << iter->second.DataSize(); | |||
| << ", given count " << inputs[i].DataSize(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| } | |||
| // copy inputs | |||
| for (size_t i = 0; i < input_infos_.size(); ++i) { | |||
| const auto &info = input_infos_[i]; | |||
| auto iter = inputs.find(info.name); | |||
| if (iter == inputs.end()) { | |||
| MS_LOG(ERROR) << "Model missing input " << info.name; | |||
| return INVALID_INPUTS; | |||
| } | |||
| const auto &input = iter->second; | |||
| const auto &input = inputs[i]; | |||
| const void *data = input.Data(); | |||
| void *input_buffer = nullptr; | |||
| @@ -333,42 +336,7 @@ Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inpu | |||
| return SUCCESS; | |||
| } | |||
| Status ModelProcess::CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, | |||
| size_t input_index) { | |||
| aclError ret; | |||
| inputs_ = aclmdlCreateDataset(); | |||
| // check inputs | |||
| if (input_index >= input_infos_.size()) { | |||
| MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given index " | |||
| << input_index; | |||
| return INVALID_INPUTS; | |||
| } | |||
| if (dvpp_outputs_buffer_dev == nullptr) { | |||
| MS_LOG(ERROR) << "input " << 0 << " cannot be null"; | |||
| return FAILED; | |||
| } | |||
| if (dvpp_outputs_buffer_size != input_infos_[input_index].buffer_size) { | |||
| MS_LOG(ERROR) << "input " << 0 << " data size not match, required size " << input_infos_[input_index].buffer_size | |||
| << ", given count " << dvpp_outputs_buffer_size; | |||
| return INVALID_INPUTS; | |||
| } | |||
| // copy inputs | |||
| auto &info = input_infos_[input_index]; | |||
| auto data_buffer = aclCreateDataBuffer(const_cast<void *>(dvpp_outputs_buffer_dev), info.buffer_size); | |||
| if (data_buffer == nullptr) { | |||
| MS_LOG(ERROR) << "Create Data Buffer failed"; | |||
| return FAILED; | |||
| } | |||
| ret = aclmdlAddDatasetBuffer(inputs_, data_buffer); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "add data buffer failed"; | |||
| aclDestroyDataBuffer(data_buffer); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status ModelProcess::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) { | |||
| Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| aclError acl_ret; | |||
| Status ret = CheckAndInitInput(inputs); | |||
| @@ -392,18 +360,7 @@ Status ModelProcess::Predict(const std::map<std::string, Buffer> &inputs, std::m | |||
| return SUCCESS; | |||
| } | |||
| size_t ModelProcess::GetBatchSize() const { | |||
| if (input_infos_.empty()) { | |||
| MS_LOG(ERROR) << "Model is not loaded"; | |||
| return 0; | |||
| } | |||
| if (input_infos_[0].dims.empty()) { | |||
| return 1; | |||
| } | |||
| return static_cast<size_t>(input_infos_[0].dims[0]); | |||
| } | |||
| Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) { | |||
| Status ModelProcess::BuildOutputs(std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| aclError ret; | |||
| // copy outputs | |||
| @@ -411,14 +368,13 @@ Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) { | |||
| aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST; | |||
| for (size_t i = 0; i < output_infos_.size(); ++i) { | |||
| const auto &info = output_infos_[i]; | |||
| // todo | |||
| outputs->emplace(info.name, Buffer()); | |||
| auto output = outputs->rbegin()->second; | |||
| if (!output.ResizeData(info.buffer_size)) { | |||
| outputs->emplace_back(Buffer()); | |||
| auto output = outputs->rbegin(); | |||
| if (!output->ResizeData(info.buffer_size)) { | |||
| MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size; | |||
| return FAILED; | |||
| } | |||
| ret = aclrtMemcpy(output.MutableData(), output.DataSize(), info.device_data, info.buffer_size, kind); | |||
| ret = aclrtMemcpy(output->MutableData(), output->DataSize(), info.device_data, info.buffer_size, kind); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device") | |||
| << " to host failed, memory size " << info.buffer_size; | |||
| @@ -428,13 +384,15 @@ Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) { | |||
| return SUCCESS; | |||
| } | |||
| Status ModelProcess::GetInputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| ConstructTensorDesc(input_infos_, tensor_list); | |||
| Status ModelProcess::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| ConstructTensorDesc(input_infos_, names, shapes, data_types, mem_sizes); | |||
| return SUCCESS; | |||
| } | |||
| Status ModelProcess::GetOutputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| ConstructTensorDesc(output_infos_, tensor_list); | |||
| Status ModelProcess::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| ConstructTensorDesc(output_infos_, names, shapes, data_types, mem_sizes); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H | |||
| #define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H | |||
| #ifndef MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H | |||
| #define MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <map> | |||
| @@ -34,12 +34,6 @@ struct AclTensorInfo { | |||
| std::string name; | |||
| }; | |||
| struct ImagesDvppOutput { | |||
| void *buffer_device = nullptr; | |||
| size_t buffer_size = 0; | |||
| size_t input_index = 0; | |||
| }; | |||
| class ModelProcess { | |||
| public: | |||
| ModelProcess() | |||
| @@ -53,24 +47,23 @@ class ModelProcess { | |||
| ~ModelProcess() {} | |||
| Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id); | |||
| Status UnLoad(); | |||
| Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs); | |||
| Status PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); | |||
| Status PreInitModelResource(); | |||
| Status GetInputsInfo(std::vector<Tensor> *tensor_list) const; | |||
| Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const; | |||
| Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const; | |||
| Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const; | |||
| // override this method to avoid request/reply data copy | |||
| void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; } | |||
| size_t GetBatchSize() const; | |||
| void set_model_id(uint32_t model_id) { model_id_ = model_id; } | |||
| uint32_t model_id() const { return model_id_; } | |||
| private: | |||
| Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset); | |||
| Status CheckAndInitInput(const std::map<std::string, Buffer> &inputs); | |||
| Status CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, | |||
| size_t input_index); | |||
| Status BuildOutputs(std::map<std::string, Buffer> *outputs); | |||
| Status CheckAndInitInput(const std::vector<Buffer> &inputs); | |||
| Status BuildOutputs(std::vector<Buffer> *outputs); | |||
| Status InitInputsBuffer(); | |||
| Status InitOutputsBuffer(); | |||
| @@ -90,4 +83,4 @@ class ModelProcess { | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H | |||
| #endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/graph.h" | |||
| #include "cxx_api/graph/graph_data.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore::api { | |||
| Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {} | |||
| Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {} | |||
| ModelType Graph::ModelType() const { | |||
| MS_EXCEPTION_IF_NULL(graph_data_); | |||
| return graph_data_->ModelType(); | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/graph/graph_data.h" | |||
| #include "utils/log_adapter.h" | |||
| #ifdef ENABLE_ACL | |||
| #include "framework/common/helper/model_helper.h" | |||
| #endif | |||
| namespace mindspore::api { | |||
| Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type) | |||
| : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) { | |||
| if (model_type != ModelType::kMindIR) { | |||
| MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type; | |||
| } | |||
| func_graph_ = func_graph; | |||
| model_type_ = model_type; | |||
| } | |||
| Graph::GraphData::GraphData(Buffer om_data, enum ModelType model_type) | |||
| : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) { | |||
| if (model_type != ModelType::kOM) { | |||
| MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type; | |||
| } | |||
| #ifdef ENABLE_ACL | |||
| // check om | |||
| ge::ModelHelper helper; | |||
| ge::ModelData model_data; | |||
| model_data.model_data = om_data.MutableData(); | |||
| model_data.model_len = om_data.DataSize(); | |||
| ge::Status ret = helper.LoadModel(model_data); | |||
| if (ret != ge::SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Invalid input data cannot parse to om."; | |||
| } | |||
| om_data_ = om_data; | |||
| model_type_ = model_type; | |||
| #else | |||
| MS_LOG(EXCEPTION) << "Unsupported ModelType OM."; | |||
| #endif | |||
| } | |||
| FuncGraphPtr Graph::GraphData::GetFuncGraph() const { | |||
| if (model_type_ != ModelType::kMindIR) { | |||
| MS_LOG(ERROR) << "Invalid ModelType " << model_type_; | |||
| return nullptr; | |||
| } | |||
| return func_graph_; | |||
| } | |||
| Buffer Graph::GraphData::GetOMData() const { | |||
| if (model_type_ != ModelType::kOM) { | |||
| MS_LOG(ERROR) << "Invalid ModelType " << model_type_; | |||
| return Buffer(); | |||
| } | |||
| return om_data_; | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "include/api/graph.h" | |||
| #include "include/api/types.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore::api { | |||
| class Graph::GraphData { | |||
| public: | |||
| GraphData(); | |||
| explicit GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type = kMindIR); | |||
| GraphData(Buffer om_data, enum ModelType model_type); | |||
| enum ModelType ModelType() const { return model_type_; } | |||
| FuncGraphPtr GetFuncGraph() const; | |||
| Buffer GetOMData() const; | |||
| private: | |||
| FuncGraphPtr func_graph_; | |||
| Buffer om_data_; | |||
| enum ModelType model_type_; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H | |||
| #include <functional> | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "include/api/cell.h" | |||
| #include "include/api/graph.h" | |||
| #include "cxx_api/graph/graph_data.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore::api { | |||
| class GraphCell::GraphImpl { | |||
| public: | |||
| GraphImpl() = default; | |||
| virtual ~GraphImpl() = default; | |||
| std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; } | |||
| void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; } | |||
| virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0; | |||
| virtual Status Load() = 0; | |||
| virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0; | |||
| virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0; | |||
| protected: | |||
| std::shared_ptr<Graph> graph_; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H | |||
| @@ -0,0 +1,334 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/graph/ms/ms_graph_impl.h" | |||
| #include <algorithm> | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "cxx_api/python_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/context/context_extends.h" | |||
| #include "mindspore/core/base/base_ref_utils.h" | |||
| #include "backend/session/session_factory.h" | |||
| #include "backend/session/executor_manager.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| namespace mindspore::api { | |||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl); | |||
| static DataType TransTypeId2InferDataType(TypeId type_id) { | |||
| const std::map<TypeId, api::DataType> id2type_map{ | |||
| {TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool}, | |||
| {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8}, | |||
| {TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16}, | |||
| {TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32}, | |||
| {TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64}, | |||
| {TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16}, | |||
| {TypeId::kNumberTypeFloat32, api::kMsFloat32}, | |||
| }; | |||
| // cppcheck-suppress stlIfFind | |||
| if (auto it = id2type_map.find(type_id); it != id2type_map.end()) { | |||
| return it->second; | |||
| } | |||
| MS_LOG(WARNING) << "Unsupported data id " << type_id; | |||
| return api::kMsUnknown; | |||
| } | |||
| template <class T> | |||
| inline static void ClearIfNotNull(T *vec) { | |||
| if (vec != nullptr) { | |||
| vec->clear(); | |||
| } | |||
| } | |||
| template <class T, class U = std::vector<T>> | |||
| inline static void PushbackIfNotNull(U *vec, T &&item) { | |||
| if (vec != nullptr) { | |||
| vec->emplace_back(item); | |||
| } | |||
| } | |||
| MsGraphImpl::MsGraphImpl() | |||
| : session_impl_(nullptr), | |||
| graph_id_(0), | |||
| device_type_("Ascend"), | |||
| device_id_(Context::Instance().GetDeviceID()), | |||
| context_(nullptr), | |||
| inputs_(), | |||
| outputs_(), | |||
| input_names_(), | |||
| output_names_(), | |||
| load_flag_(false) {} | |||
| MsGraphImpl::~MsGraphImpl() { (void)FinalizeEnv(); } | |||
| Status MsGraphImpl::InitEnv() { | |||
| RegAllOpFromPython(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(ERROR) << "Get Context failed!"; | |||
| return FAILED; | |||
| } | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); | |||
| ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice); | |||
| if (!context::OpenTsd(ms_context)) { | |||
| MS_LOG(ERROR) << "Session init OpenTsd failed!"; | |||
| return FAILED; | |||
| } | |||
| session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice); | |||
| if (session_impl_ == nullptr) { | |||
| MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice | |||
| << " is available."; | |||
| return FAILED; | |||
| } | |||
| session_impl_->Init(device_id_); | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::FinalizeEnv() { | |||
| MS_LOG_INFO << "Start finalize env"; | |||
| pybind11::gil_scoped_acquire acquire; | |||
| session::ExecutorManager::Instance().Clear(); | |||
| device::KernelRuntimeManager::Instance().ClearRuntimeResource(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(ERROR) << "Get Context failed!"; | |||
| return FAILED; | |||
| } | |||
| if (!context::CloseTsd(ms_context)) { | |||
| MS_LOG(ERROR) << "CloseTsd failed!"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "End finalize env"; | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| try { | |||
| graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); | |||
| pybind11::gil_scoped_release gil_release; | |||
| return SUCCESS; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "CompileGraph failed: " << e.what(); | |||
| return FAILED; | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) { | |||
| try { | |||
| VectorRef outputs; | |||
| session_impl_->RunGraph(graph_id_, inputs, &outputs); | |||
| return TransformVectorRefToMultiTensor(outputs); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "RunGraph failed: " << e.what(); | |||
| return std::vector<tensor::TensorPtr>(); | |||
| } | |||
| } | |||
| Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| std::string error_msg; | |||
| if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) { | |||
| return Status(INVALID_INPUTS, error_msg); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) { | |||
| MS_EXCEPTION_IF_NULL(reply); | |||
| if (context_ == nullptr) { | |||
| MS_LOG(ERROR) << "rtCtx is nullptr"; | |||
| return FAILED; | |||
| } | |||
| rtError_t rt_ret = rtCtxSetCurrent(context_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Set Ascend rtCtx failed"; | |||
| return FAILED; | |||
| } | |||
| vector<tensor::TensorPtr> inputs; | |||
| for (size_t i = 0; i < request.size(); i++) { | |||
| auto &item = request[i]; | |||
| auto input = inputs_[i]; | |||
| if (input->Size() != item.DataSize()) { | |||
| MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size " | |||
| << input->Size(); | |||
| return FAILED; | |||
| } | |||
| auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Tensor copy failed"; | |||
| return FAILED; | |||
| } | |||
| inputs.push_back(input); | |||
| } | |||
| vector<tensor::TensorPtr> outputs = RunGraph(inputs); | |||
| if (outputs.empty()) { | |||
| MS_LOG(ERROR) << "Execute Model Failed"; | |||
| return FAILED; | |||
| } | |||
| reply->clear(); | |||
| std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), | |||
| [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); }); | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "PrepareModel failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| ClearIfNotNull(names); | |||
| ClearIfNotNull(shapes); | |||
| ClearIfNotNull(data_types); | |||
| ClearIfNotNull(mem_sizes); | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto &tensor = inputs_[i]; | |||
| PushbackIfNotNull(names, input_names_[i]); | |||
| PushbackIfNotNull(shapes, tensor->shape()); | |||
| PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type())); | |||
| PushbackIfNotNull(mem_sizes, tensor->DataSize()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "PrepareModel failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| ClearIfNotNull(names); | |||
| ClearIfNotNull(shapes); | |||
| ClearIfNotNull(data_types); | |||
| ClearIfNotNull(mem_sizes); | |||
| for (size_t i = 0; i < outputs_.size(); i++) { | |||
| auto &tensor = outputs_[i]; | |||
| PushbackIfNotNull(names, output_names_[i]); | |||
| PushbackIfNotNull(shapes, tensor->shape()); | |||
| PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type())); | |||
| PushbackIfNotNull(mem_sizes, tensor->DataSize()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::Load() { | |||
| // check graph type | |||
| if (graph_->ModelType() != ModelType::kMindIR) { | |||
| MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| const auto &graph_data = GraphImpl::MutableGraphData(); | |||
| MS_EXCEPTION_IF_NULL(graph_data); | |||
| auto func_graph = graph_data->GetFuncGraph(); | |||
| // init | |||
| Status ret = InitEnv(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "InitEnv failed."; | |||
| return FAILED; | |||
| } | |||
| // load model | |||
| if (!load_flag_) { | |||
| ret = CompileGraph(func_graph); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Compile graph model failed"; | |||
| return FAILED; | |||
| } | |||
| session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_); | |||
| session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_); | |||
| if (inputs_.empty() || inputs_.size() != input_names_.size()) { | |||
| MS_LOG_ERROR << "Get model inputs info failed"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_.empty() || outputs_.size() != output_names_.size()) { | |||
| MS_LOG_ERROR << "Get model outputs info failed"; | |||
| return FAILED; | |||
| } | |||
| // save d context | |||
| rtError_t rt_ret = rtCtxGetCurrent(&context_); | |||
| if (rt_ret != RT_ERROR_NONE || context_ == nullptr) { | |||
| MS_LOG(ERROR) << "the ascend device context is null"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Load model success"; | |||
| load_flag_ = true; | |||
| } | |||
| rtError_t rt_ret = rtCtxSetCurrent(context_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Set the ascend device context failed"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "PrepareModel failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| if (inputs.size() != inputs_.size()) { | |||
| MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| for (size_t i = 0; i < inputs_.size(); ++i) { | |||
| if (inputs[i].DataSize() != inputs_[i]->Size()) { | |||
| MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count " | |||
| << inputs[i].DataSize(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| } | |||
| if (ExecuteModel(inputs, outputs) != SUCCESS) { | |||
| MS_LOG(ERROR) << "Execute Model Failed"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_.size() != outputs->size()) { | |||
| MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info " | |||
| << outputs_.size(); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| #include <functional> | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "include/api/status.h" | |||
| #include "include/api/graph.h" | |||
| #include "cxx_api/graph/graph_impl.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "ir/anf.h" | |||
| #include "cxx_api/model/model_impl.h" | |||
| #include "runtime/context.h" | |||
| namespace mindspore::api { | |||
| class MsGraphImpl : public GraphCell::GraphImpl { | |||
| public: | |||
| MsGraphImpl(); | |||
| ~MsGraphImpl() override; | |||
| Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; | |||
| Status Load() override; | |||
| Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; | |||
| Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; | |||
| private: | |||
| Status InitEnv(); | |||
| Status FinalizeEnv(); | |||
| Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr); | |||
| Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const; | |||
| std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs); | |||
| Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); | |||
| std::shared_ptr<session::SessionBasic> session_impl_; | |||
| uint32_t graph_id_; | |||
| std::string device_type_; | |||
| uint32_t device_id_; | |||
| rtContext_t context_; | |||
| std::vector<tensor::TensorPtr> inputs_; | |||
| std::vector<tensor::TensorPtr> outputs_; | |||
| std::vector<std::string> input_names_; | |||
| std::vector<std::string> output_names_; | |||
| bool load_flag_; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| @@ -16,216 +16,57 @@ | |||
| #include "cxx_api/model/acl/acl_model.h" | |||
| #include <memory> | |||
| #include "utils/context/context_extends.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "cxx_api/python_utils.h" | |||
| namespace mindspore::api { | |||
| std::weak_ptr<AclModel::AclEnvGuard> AclModel::global_acl_env_; | |||
| std::mutex AclModel::global_acl_env_mutex_; | |||
| Status AclModel::InitEnv() { | |||
| if (init_flag_) { | |||
| API_FACTORY_REG(ModelImpl, Ascend310, AclModel); | |||
| Status AclModel::Build(const std::map<std::string, std::string> &options_map) { | |||
| MS_LOG(INFO) << "Start build model."; | |||
| MS_EXCEPTION_IF_NULL(graph_); | |||
| RegAllOpFromPython(); | |||
| std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(options_map); | |||
| std::string options_str = GenerateOptionsStr(options_map); | |||
| MS_EXCEPTION_IF_NULL(options); | |||
| if (graph_cell_ != nullptr && options_str == options_str_) { | |||
| MS_LOG(INFO) << "This model has been built, skip."; | |||
| return SUCCESS; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(options_); | |||
| aclError ret; | |||
| { | |||
| std::lock_guard<std::mutex> lock(global_acl_env_mutex_); | |||
| acl_env_ = global_acl_env_.lock(); | |||
| if (acl_env_ != nullptr) { | |||
| if (options_->dump_cfg_path.empty()) { | |||
| MS_LOG(INFO) << "Acl has been initialized, skip."; | |||
| } else { | |||
| MS_LOG(WARNING) << "Acl has been initialized, skip, so dump config will be ignored."; | |||
| } | |||
| } else { | |||
| acl_env_ = std::make_shared<AclEnvGuard>(options_->dump_cfg_path); | |||
| if (acl_env_->GetErrno() != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Execute aclInit Failed"; | |||
| return FAILED; | |||
| } | |||
| global_acl_env_ = acl_env_; | |||
| MS_LOG(INFO) << "Acl init success"; | |||
| if (graph_cell_ == nullptr && graph_->ModelType() == ModelType::kOM) { | |||
| graph_cell_ = std::make_shared<GraphCell>(graph_); | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| if (!options_map.empty()) { | |||
| MS_LOG(WARNING) << "All build options will be ignored."; | |||
| } | |||
| } | |||
| ret = aclrtSetDevice(device_id_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Open device " << device_id_ << " success"; | |||
| ret = aclrtCreateContext(&context_, device_id_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Acl create context failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Create context success"; | |||
| ret = aclrtSetCurrentContext(context_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Acl set current context failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Set context success"; | |||
| ret = aclrtCreateStream(&stream_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Acl create stream failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Create stream success"; | |||
| aclrtRunMode run_mode; | |||
| ret = aclrtGetRunMode(&run_mode); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Acl get run mode failed"; | |||
| return FAILED; | |||
| } | |||
| bool is_device = (run_mode == ACL_DEVICE); | |||
| model_process_.SetIsDevice(is_device); | |||
| MS_LOG(INFO) << "Get run mode success is device input/output " << is_device; | |||
| if (dvpp_process_.InitResource(stream_) != SUCCESS) { | |||
| MS_LOG(ERROR) << "DVPP init resource failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Init acl success, device id " << device_id_; | |||
| init_flag_ = true; | |||
| return SUCCESS; | |||
| } | |||
| Status AclModel::FinalizeEnv() { | |||
| if (!init_flag_) { | |||
| return SUCCESS; | |||
| } | |||
| dvpp_process_.Finalize(); | |||
| aclError ret; | |||
| if (stream_ != nullptr) { | |||
| ret = aclrtDestroyStream(stream_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Destroy stream failed"; | |||
| } | |||
| stream_ = nullptr; | |||
| } | |||
| MS_LOG(INFO) << "End to destroy stream"; | |||
| if (context_ != nullptr) { | |||
| ret = aclrtDestroyContext(context_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Destroy context failed"; | |||
| } | |||
| context_ = nullptr; | |||
| } | |||
| MS_LOG(INFO) << "End to destroy context"; | |||
| ret = aclrtResetDevice(device_id_); | |||
| if (ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Reset devie " << device_id_ << " failed"; | |||
| } | |||
| MS_LOG(INFO) << "End to reset device " << device_id_; | |||
| init_flag_ = false; | |||
| return SUCCESS; | |||
| } | |||
| Status AclModel::LoadModel(const Buffer &model_data, ModelType type, | |||
| const std::map<std::string, std::string> &options) { | |||
| if (load_flag_) { | |||
| MS_LOG(ERROR) << "Model has been loaded."; | |||
| auto func_graph = ModelImpl::GetFuncGraph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| model_converter_.set_options(options.get()); | |||
| auto om_data = model_converter_.LoadMindIR(func_graph); | |||
| if (om_data.Data() == nullptr || om_data.DataSize() == 0) { | |||
| MS_LOG(ERROR) << "Load MindIR failed."; | |||
| return FAILED; | |||
| } | |||
| options_ = std::make_unique<AclModelOptions>(options); | |||
| MS_EXCEPTION_IF_NULL(options_); | |||
| Status ret = InitEnv(); | |||
| auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM)); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto graph_cell = std::make_shared<GraphCell>(graph); | |||
| MS_EXCEPTION_IF_NULL(graph_cell); | |||
| auto ret = ModelImpl::Load(graph_cell); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "InitEnv failed."; | |||
| return FAILED; | |||
| MS_LOG(ERROR) << "Load failed."; | |||
| return ret; | |||
| } | |||
| Buffer om_data; | |||
| if (type == ModelType::kMindIR) { | |||
| model_converter_.set_options(options_.get()); | |||
| om_data = model_converter_.LoadMindIR(model_data); | |||
| } else if (type == ModelType::kAIR) { | |||
| model_converter_.set_options(options_.get()); | |||
| om_data = model_converter_.LoadAscendIR(model_data); | |||
| } else if (type == ModelType::kOM) { | |||
| om_data = model_data; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported model type " << type; | |||
| return FAILED; | |||
| } | |||
| // acl load model | |||
| uint32_t acl_model_id; | |||
| auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id); | |||
| if (acl_ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed."; | |||
| return FAILED; | |||
| } | |||
| // acl init model resource | |||
| model_process_.set_model_id(acl_model_id); | |||
| ret = model_process_.PreInitModelResource(); | |||
| if (ret != SUCCESS) { | |||
| (void)aclmdlUnload(acl_model_id); | |||
| MS_LOG(ERROR) << "Pre init model resource failed."; | |||
| return FAILED; | |||
| } | |||
| // acl init dvpp | |||
| ret = dvpp_process_.InitWithJsonConfig(options_->dvpp_cfg_path); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "DVPP config file parse error."; | |||
| return FAILED; | |||
| } | |||
| load_flag_ = true; | |||
| return SUCCESS; | |||
| } | |||
| Status AclModel::LoadModel(const std::string &file_name, ModelType type, | |||
| const std::map<std::string, std::string> &options) { | |||
| Buffer model_data = ModelConverter::ReadFile(file_name); | |||
| if (model_data.DataSize() == 0) { | |||
| MS_LOG(ERROR) << "Read file " << file_name << " failed."; | |||
| return FAILED; | |||
| } | |||
| return LoadModel(model_data, type, options); | |||
| } | |||
| Status AclModel::UnloadModel() { | |||
| if (!load_flag_) { | |||
| MS_LOG(WARNING) << "No model is loaded, skip unload."; | |||
| return SUCCESS; | |||
| } | |||
| aclError rt_ret = aclrtSetCurrentContext(context_); | |||
| if (rt_ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Set the ascend device context failed"; | |||
| return FAILED; | |||
| } | |||
| Status ret = model_process_.UnLoad(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Unload model inner failed."; | |||
| return FAILED; | |||
| } | |||
| ret = FinalizeEnv(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "FinalizeEnv failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Unload model success."; | |||
| load_flag_ = false; | |||
| // save result | |||
| graph_cell_ = graph_cell; | |||
| options_ = std::move(options); | |||
| options_str_ = options_str; | |||
| MS_LOG(INFO) << "Build model success."; | |||
| return SUCCESS; | |||
| } | |||
| @@ -239,45 +80,49 @@ Status AclModel::Eval(const DataSet &, std::map<std::string, Buffer> *) { | |||
| return FAILED; | |||
| } | |||
| Status AclModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) { | |||
| Status AclModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (!load_flag_) { | |||
| MS_LOG(ERROR) << "No model is loaded, predict failed."; | |||
| if (graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid data, graph_ is null."; | |||
| return FAILED; | |||
| } | |||
| aclError rt_ret = aclrtSetCurrentContext(context_); | |||
| if (rt_ret != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Set the ascend device context failed"; | |||
| if (graph_cell_ == nullptr) { | |||
| MS_LOG(WARNING) << "Model has not been built, it will be built with default options"; | |||
| Status ret = Build({}); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Build model failed."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| Status ret = graph_cell_->Run(inputs, outputs); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Run graph failed."; | |||
| return FAILED; | |||
| } | |||
| return model_process_.Predict(inputs, outputs); | |||
| } | |||
| Status AclModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| MS_EXCEPTION_IF_NULL(tensor_list); | |||
| return model_process_.GetInputsInfo(tensor_list); | |||
| return SUCCESS; | |||
| } | |||
| Status AclModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| MS_EXCEPTION_IF_NULL(tensor_list); | |||
| return model_process_.GetOutputsInfo(tensor_list); | |||
| Status AclModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| AclModel::AclEnvGuard::AclEnvGuard(const std::string &cfg_file) { | |||
| errno_ = aclInit(common::SafeCStr(cfg_file)); | |||
| if (errno_ != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Execute aclInit Failed"; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Acl init success"; | |||
| Status AclModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| AclModel::AclEnvGuard::~AclEnvGuard() { | |||
| errno_ = aclFinalize(); | |||
| if (errno_ != ACL_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Finalize acl failed"; | |||
| std::string AclModel::GenerateOptionsStr(const std::map<std::string, std::string> &options) { | |||
| std::string ret; | |||
| for (auto &[key, value] : options) { | |||
| ret += key + "^" + value + "^^"; | |||
| } | |||
| MS_LOG(INFO) << "Acl finalize success"; | |||
| return ret; | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -23,77 +23,38 @@ | |||
| #include <memory> | |||
| #include <map> | |||
| #include "ir/anf.h" | |||
| #include "include/api/cell.h" | |||
| #include "include/api/status.h" | |||
| #include "cxx_api/model/model_impl.h" | |||
| #include "cxx_api/model/acl/dvpp_process.h" | |||
| #include "cxx_api/model/acl/model_process.h" | |||
| #include "cxx_api/model/acl/model_converter.h" | |||
| #include "cxx_api/model/acl/acl_model_options.h" | |||
| #include "ir/tensor.h" | |||
| #include "ir/anf.h" | |||
| namespace mindspore::api { | |||
| class AclModel : public ModelImpl { | |||
| public: | |||
| explicit AclModel(uint32_t device_id) | |||
| : init_flag_(false), | |||
| load_flag_(false), | |||
| device_type_("AscendCL"), | |||
| device_id_(device_id), | |||
| context_(nullptr), | |||
| stream_(nullptr), | |||
| acl_env_(nullptr), | |||
| model_process_(), | |||
| dvpp_process_(), | |||
| model_converter_(), | |||
| options_(nullptr) {} | |||
| AclModel() : model_converter_(), options_(nullptr), options_str_() {} | |||
| ~AclModel() = default; | |||
| Status LoadModel(const Buffer &model_data, ModelType type, | |||
| const std::map<std::string, std::string> &options) override; | |||
| Status LoadModel(const std::string &file_name, ModelType type, | |||
| const std::map<std::string, std::string> &options) override; | |||
| Status UnloadModel() override; | |||
| Status Build(const std::map<std::string, std::string> &options_map) override; | |||
| Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override; | |||
| Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override; | |||
| Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override; | |||
| Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; | |||
| Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override; | |||
| Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override; | |||
| Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override; | |||
| Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override; | |||
| private: | |||
| bool init_flag_; | |||
| bool load_flag_; | |||
| std::string device_type_; | |||
| int32_t device_id_; | |||
| aclrtContext context_; | |||
| aclrtStream stream_; | |||
| class AclEnvGuard; | |||
| std::shared_ptr<AclEnvGuard> acl_env_; | |||
| static std::weak_ptr<AclEnvGuard> global_acl_env_; | |||
| static std::mutex global_acl_env_mutex_; | |||
| static std::string GenerateOptionsStr(const std::map<std::string, std::string> &options); | |||
| ModelProcess model_process_; | |||
| DvppProcess dvpp_process_; | |||
| std::shared_ptr<GraphCell> graph_cell_; | |||
| ModelConverter model_converter_; | |||
| std::unique_ptr<AclModelOptions> options_; | |||
| Status InitEnv(); | |||
| Status FinalizeEnv(); | |||
| }; | |||
| class AclModel::AclEnvGuard { | |||
| public: | |||
| explicit AclEnvGuard(const std::string &cfg_file); | |||
| ~AclEnvGuard(); | |||
| aclError GetErrno() const { return errno_; } | |||
| private: | |||
| aclError errno_; | |||
| std::string options_str_; | |||
| }; | |||
| API_REG_MODEL(AscendCL, AclModel); | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H | |||
| @@ -1,160 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H | |||
| #define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <map> | |||
| #include "acl/acl.h" | |||
| #include "acl/acl_mdl.h" | |||
| #include "acl/acl_rt.h" | |||
| #include "acl/ops/acl_dvpp.h" | |||
| #include "include/api/status.h" | |||
| namespace mindspore::api { | |||
| struct DvppDecodePara { | |||
| acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; | |||
| }; | |||
| struct DvppResizePara { | |||
| uint32_t output_width = 0; | |||
| uint32_t output_height = 0; | |||
| }; | |||
| enum DvppCropType { | |||
| // crop left,top,right,bottom is given in config | |||
| kDvppCropTypeOffset = 0, | |||
| // crop left,top,right,bottom is calculated by image width/height and output crop width/height | |||
| kDvppCropTypeCentre = 1, | |||
| }; | |||
| struct DvppRoiArea { | |||
| uint32_t left = 0; | |||
| uint32_t top = 0; | |||
| uint32_t right = 0; | |||
| uint32_t bottom = 0; | |||
| }; | |||
| struct DvppCropInfo { | |||
| DvppCropType crop_type = kDvppCropTypeOffset; | |||
| DvppRoiArea crop_area; // when kDvppCropTypeOffset | |||
| uint32_t crop_width = 0; // when kDvppCropTypeCentre | |||
| uint32_t crop_height = 0; // when kDvppCropTypeCentre | |||
| }; | |||
| struct DvppCropPara { | |||
| DvppCropInfo crop_info; | |||
| uint32_t output_width = 0; | |||
| uint32_t output_height = 0; | |||
| }; | |||
| struct DvppCropAndPastePara { | |||
| DvppCropInfo crop_info; | |||
| DvppRoiArea paste_area; | |||
| uint32_t output_width = 0; | |||
| uint32_t output_height = 0; | |||
| }; | |||
| class DvppProcess { | |||
| public: | |||
| DvppProcess(); | |||
| ~DvppProcess(); | |||
| Status InitResource(aclrtStream stream); | |||
| void Finalize(); | |||
| Status InitJpegDecodePara(const DvppDecodePara &decode_para); // jpeg decode + (resize | crop) | |||
| Status InitResizePara(const DvppResizePara &resize_para); // jpeg decode + resize | |||
| Status InitCropPara(const DvppCropPara &crop_para); // jpeg decode + crop | |||
| Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para); // jpeg decode + crop&paste | |||
| Status InitWithJsonConfig(const std::string &json_config); | |||
| // output device buffer will be destroy by DvppProcess itself. | |||
| Status Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, size_t *output_size); | |||
| Status Process(const std::vector<const void *> &pic_buffer_list, const std::vector<size_t> &pic_buffer_size_list, | |||
| void **output_device_buffer, size_t *output_size); | |||
| bool HasLoaded() const { return loaded_flag_; } | |||
| private: | |||
| bool loaded_flag_ = false; | |||
| uint32_t pic_width_ = 0; | |||
| uint32_t pic_height_ = 0; | |||
| DvppDecodePara decode_para_; | |||
| DvppResizePara resize_para_; | |||
| DvppCropPara crop_para_; | |||
| DvppCropAndPastePara crop_and_paste_para_; | |||
| // only one of the resize or crop flag can be true | |||
| bool to_resize_flag_ = false; | |||
| bool to_crop_flag_ = false; | |||
| bool to_crop_and_paste_flag_ = false; | |||
| void *input_pic_dev_buffer_ = nullptr; | |||
| uint32_t input_pic_buffer_size_ = 0; | |||
| uint32_t decode_output_buffer_size_ = 0; | |||
| void *decode_output_buffer_dev_ = nullptr; | |||
| acldvppPicDesc *decode_output_desc_ = nullptr; | |||
| acldvppResizeConfig *resize_config_ = nullptr; | |||
| acldvppRoiConfig *crop_area_ = nullptr; | |||
| acldvppRoiConfig *paste_area_ = nullptr; | |||
| acldvppPicDesc *vpc_output_desc_ = nullptr; | |||
| void *vpc_output_buffer_dev_ = nullptr; // vpc_output_buffer_size_ length | |||
| uint32_t vpc_output_buffer_size_ = 0; | |||
| void *batch_vpc_output_buffer_dev_ = nullptr; // batch_size_ * vpc_output_buffer_size_ length | |||
| uint32_t batch_size_ = 0; | |||
| aclrtStream stream_ = nullptr; | |||
| acldvppChannelDesc *dvpp_channel_desc_ = nullptr; | |||
| uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const; | |||
| uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, acldvppPixelFormat pixel_format) const; | |||
| Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height); | |||
| Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height); | |||
| Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size); | |||
| Status InitDecodeOutputDesc(uint32_t image_width, | |||
| uint32_t image_height); // decode_output_desc_, decode_output_buffer_dev_ | |||
| Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height); | |||
| Status CheckAndAdjustRoiArea(DvppRoiArea *area); | |||
| Status UpdateCropArea(uint32_t image_width, uint32_t image_height); | |||
| Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const; | |||
| void DestroyDecodeDesc(); | |||
| Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height, | |||
| acldvppPixelFormat pixel_format); // vpc_output_desc_, vpc_output_buffer_dev_batch_ | |||
| Status InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area); | |||
| Status InitCommonCropPara(uint32_t out_width, uint32_t out_height, DvppCropInfo *crop_info); | |||
| Status InitResizeOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, resize_config | |||
| Status InitCropOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_ | |||
| Status InitCropAndPasteOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_ | |||
| void DestroyVpcOutputDesc(); | |||
| Status ProcessDecode(); | |||
| Status ProcessResize(); | |||
| Status ProcessCrop(); | |||
| Status ProcessCropAndPaste(); | |||
| void DestroyResource(); | |||
| Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width, | |||
| uint32_t *image_height); | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H | |||
| @@ -16,17 +16,13 @@ | |||
| #include "cxx_api/model/acl/model_converter.h" | |||
| #include <memory> | |||
| #include "pybind11/pybind11.h" | |||
| #include "transform/graph_ir/convert.h" | |||
| #include "transform/graph_ir/graph_runner.h" | |||
| #include "core/load_mindir/load_model.h" | |||
| #include "mindspore/core/utils/ms_context.h" | |||
| #include "backend/kernel_compiler/oplib/oplib.h" | |||
| #include "include/api/serialization.h" | |||
| #include "graph/model.h" | |||
| #include "cxx_api/model/model_converter_utils/multi_process.h" | |||
| namespace py = pybind11; | |||
| #include "cxx_api/python_utils.h" | |||
| namespace mindspore::api { | |||
| namespace { | |||
| @@ -74,19 +70,8 @@ bool CreateSessionAndGraphRunner() { | |||
| return true; | |||
| } | |||
| } // namespace | |||
| std::shared_ptr<FuncGraph> ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) { | |||
| try { | |||
| auto anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize()); | |||
| return anf_graph; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Load MindIR failed."; | |||
| return nullptr; | |||
| } | |||
| } | |||
| transform::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph) { | |||
| for (auto &anf_node : anf_graph->parameters()) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| @@ -166,88 +151,31 @@ Buffer ModelConverter::BuildAirModel(const transform::DfGraphPtr &graph, | |||
| return Buffer(model.data.get(), model.length); | |||
| } | |||
| void ModelConverter::RegAllOp() { | |||
| static std::mutex init_mutex; | |||
| static bool Initialized = false; | |||
| std::lock_guard<std::mutex> lock(init_mutex); | |||
| if (Initialized) { | |||
| return; | |||
| } | |||
| Initialized = true; | |||
| MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| Py_Initialize(); | |||
| auto c_expression = PyImport_ImportModule("mindspore._c_expression"); | |||
| MS_EXCEPTION_IF_NULL(c_expression); | |||
| PyObject *c_expression_dict = PyModule_GetDict(c_expression); | |||
| MS_EXCEPTION_IF_NULL(c_expression_dict); | |||
| PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); | |||
| MS_EXCEPTION_IF_NULL(op_info_loader_class); | |||
| PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); | |||
| MS_EXCEPTION_IF_NULL(op_info_loader); | |||
| PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); | |||
| MS_EXCEPTION_IF_NULL(op_info_loader_ins); | |||
| auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); | |||
| MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul); | |||
| auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); | |||
| auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr); | |||
| for (auto op_info : *all_ops_info) { | |||
| kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info)); | |||
| } | |||
| all_ops_info->clear(); | |||
| delete all_ops_info; | |||
| Py_DECREF(op_info_loader); | |||
| Py_DECREF(op_info_loader_class); | |||
| Py_DECREF(c_expression_dict); | |||
| Py_DECREF(c_expression); | |||
| } | |||
| Buffer ModelConverter::ReadFile(const std::string &file) { | |||
| Buffer buffer; | |||
| if (file.empty()) { | |||
| MS_LOG(ERROR) << "Pointer file is nullptr"; | |||
| return buffer; | |||
| } | |||
| std::string realPath = file; | |||
| std::ifstream ifs(realPath); | |||
| if (!ifs.good()) { | |||
| MS_LOG(ERROR) << "File: " << realPath << " is not exist"; | |||
| return buffer; | |||
| } | |||
| if (!ifs.is_open()) { | |||
| MS_LOG(ERROR) << "File: " << realPath << "open failed"; | |||
| return buffer; | |||
| } | |||
| ifs.seekg(0, std::ios::end); | |||
| size_t size = ifs.tellg(); | |||
| buffer.ResizeData(size); | |||
| if (buffer.DataSize() != size) { | |||
| MS_LOG(ERROR) << "Malloc buf failed, file: " << realPath; | |||
| ifs.close(); | |||
| return buffer; | |||
| } | |||
| ifs.seekg(0, std::ios::beg); | |||
| ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size); | |||
| ifs.close(); | |||
| return buffer; | |||
| } | |||
| Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { | |||
| if (Py_IsInitialized() == 0) { | |||
| Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) { | |||
| if (!PythonIsInited()) { | |||
| MS_LOG_INFO << "Call LoadMindIRInner directly"; | |||
| return LoadMindIRInner(model_data); | |||
| return LoadMindIRInner(func_graph); | |||
| } | |||
| MultiProcess multi_process; | |||
| Buffer buffer_ret; | |||
| auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status { | |||
| auto parent_process = [&func_graph, &buffer_ret, this](MultiProcess *multi_process) -> Status { | |||
| MS_EXCEPTION_IF_NULL(multi_process); | |||
| auto df_graph = ConvertFuncGraphToAIR(func_graph); | |||
| if (df_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed."; | |||
| return FAILED; | |||
| } | |||
| ge::Model model; | |||
| ge::Buffer model_data; | |||
| model.SetGraph(*df_graph); | |||
| auto ge_ret = model.Save(model_data); | |||
| if (ge_ret != ge::SUCCESS) { | |||
| MS_LOG(ERROR) << "Save ge model to buffer failed."; | |||
| return FAILED; | |||
| } | |||
| // send original model to child | |||
| auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize()); | |||
| auto status = multi_process->SendMsg(model_data.data(), model_data.size()); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Send original model to child process failed"; | |||
| return FAILED; | |||
| @@ -277,7 +205,7 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { | |||
| MS_LOG_ERROR << "Receive original model from parent process failed"; | |||
| return FAILED; | |||
| } | |||
| Buffer model_result = LoadMindIRInner(model); | |||
| Buffer model_result = LoadAscendIRInner(model); | |||
| if (model_result.DataSize() == 0) { | |||
| MS_LOG_ERROR << "Convert model from MindIR to OM failed"; | |||
| return FAILED; | |||
| @@ -300,7 +228,7 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { | |||
| } | |||
| Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { | |||
| if (Py_IsInitialized() == 0) { | |||
| if (!PythonIsInited()) { | |||
| MS_LOG_INFO << "Call LoadAscendIRInner directly"; | |||
| return LoadAscendIRInner(model_data); | |||
| } | |||
| @@ -361,10 +289,8 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { | |||
| return buffer_ret; | |||
| } | |||
| Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) { | |||
| RegAllOp(); | |||
| Py_Initialize(); | |||
| auto func_graph = ConvertMindIrToFuncGraph(model_data); | |||
| Buffer ModelConverter::LoadMindIRInner(const FuncGraphPtr &func_graph) { | |||
| RegAllOpFromPython(); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed."; | |||
| return Buffer(); | |||
| @@ -386,7 +312,7 @@ Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) { | |||
| } | |||
| Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) { | |||
| RegAllOp(); | |||
| RegAllOpFromPython(); | |||
| ge::Model load_model = ge::Model("loadmodel", "version2"); | |||
| ge::Status ret = | |||
| ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.Data()), model_data.DataSize(), load_model); | |||
| @@ -32,21 +32,17 @@ class ModelConverter { | |||
| public: | |||
| ModelConverter() : options_(nullptr) {} | |||
| Buffer LoadMindIR(const Buffer &model_data); | |||
| Buffer LoadMindIR(const FuncGraphPtr &func_graph); | |||
| Buffer LoadAscendIR(const Buffer &model_data); | |||
| void set_options(AclModelOptions *options) { options_ = options; } | |||
| static Buffer ReadFile(const std::string &file); | |||
| static void RegAllOp(); | |||
| private: | |||
| std::shared_ptr<FuncGraph> ConvertMindIrToFuncGraph(const Buffer &model_data); | |||
| transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph); | |||
| Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &acl_options); | |||
| AclModelOptions *options_; | |||
| Buffer LoadMindIRInner(const Buffer &model_data); | |||
| Buffer LoadMindIRInner(const FuncGraphPtr &func_graph); | |||
| Buffer LoadAscendIRInner(const Buffer &model_data); | |||
| }; | |||
| } // namespace mindspore::api | |||
| @@ -14,93 +14,59 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/model.h" | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/model/model_impl.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore::api { | |||
| const char *kDeviceTypeAscendCL = "AscendCL"; | |||
| const char *kDeviceTypeAscendMS = "AscendMS"; | |||
| Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) { | |||
| Status Model::Build(const std::map<std::string, std::string> &options) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->LoadModel(model_data, type, options); | |||
| return impl_->Build(options); | |||
| } | |||
| Status Model::LoadModel(const std::string &file_name, ModelType type, | |||
| const std::map<std::string, std::string> &options) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->LoadModel(file_name, type, options); | |||
| } | |||
| Status Model::UnloadModel() { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->UnloadModel(); | |||
| } | |||
| Status Model::Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) { | |||
| Status Model::Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->Train(dataset, outputs); | |||
| } | |||
| Status Model::Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) { | |||
| Status Model::Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->Eval(dataset, outputs); | |||
| } | |||
| Status Model::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) { | |||
| Status Model::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->Predict(inputs, outputs); | |||
| } | |||
| Status Model::Predict(const std::vector<Buffer> &inputs, std::map<std::string, Buffer> *outputs) { | |||
| std::vector<Tensor> tensor_list; | |||
| auto ret = GetInputsInfo(&tensor_list); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "GetInputsInfo failed."; | |||
| return ret; | |||
| } | |||
| if (inputs.size() != tensor_list.size()) { | |||
| MS_LOG(ERROR) << "Model need " << tensor_list.size() << " inputs, but given " << inputs.size(); | |||
| return FAILED; | |||
| } | |||
| std::map<std::string, Buffer> inputs_with_map; | |||
| for (size_t i = 0; i < tensor_list.size(); ++i) { | |||
| inputs_with_map.emplace(tensor_list[i].Name(), inputs[i]); | |||
| } | |||
| return Predict(inputs_with_map, outputs); | |||
| } | |||
| Status Model::GetInputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| Status Model::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->GetInputsInfo(tensor_list); | |||
| return impl_->GetInputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| Status Model::GetOutputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| Status Model::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->GetOutputsInfo(tensor_list); | |||
| return impl_->GetOutputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| Model::Model(const std::string &device_type, uint32_t device_id) | |||
| : impl_(ModelFactory::Instance().Create(device_type, device_id)) { | |||
| Model::Model(const GraphCell &graph_cell) | |||
| : impl_(Factory<ModelImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) { | |||
| if (impl_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed"; | |||
| MS_LOG(EXCEPTION) << "Create session type " << Context::Instance().GetDeviceTarget() << " failed"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(graph_cell.GetGraph()); | |||
| impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph())); | |||
| } | |||
| Model::Model(NetWork network, const std::string &device_type, uint32_t device_id) { | |||
| // todo | |||
| if (impl_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Create session type " << device_type << " failed"; | |||
| } | |||
| } | |||
| Model::Model(const std::vector<Output> &network) { MS_LOG(EXCEPTION) << "Unsupported feature."; } | |||
| Model::~Model() {} | |||
| bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { | |||
| return ModelFactory::Instance().CheckModelSupport(device_type, model_type); | |||
| bool Model::CheckModelSupport(const std::string &device_type, ModelType) { | |||
| return Factory<ModelImpl>::Instance().CheckModelSupport(device_type); | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -22,7 +22,10 @@ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "include/api/model.h" | |||
| #include "include/api/graph.h" | |||
| #include "cxx_api/graph/graph_data.h" | |||
| #include "utils/utils.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore::api { | |||
| class ModelImpl { | |||
| @@ -30,70 +33,39 @@ class ModelImpl { | |||
| ModelImpl() = default; | |||
| virtual ~ModelImpl() = default; | |||
| virtual Status LoadModel(const Buffer &model_data, ModelType type, | |||
| const std::map<std::string, std::string> &options) = 0; | |||
| virtual Status LoadModel(const std::string &file_name, ModelType type, | |||
| const std::map<std::string, std::string> &options) = 0; | |||
| virtual Status UnloadModel() = 0; | |||
| virtual Status Build(const std::map<std::string, std::string> &options) = 0; | |||
| virtual Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0; | |||
| virtual Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) = 0; | |||
| virtual Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) = 0; | |||
| virtual Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0; | |||
| virtual Status GetInputsInfo(std::vector<Tensor> *tensor_list) const = 0; | |||
| virtual Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const = 0; | |||
| }; | |||
| using ModelCreator = std::function<std::shared_ptr<ModelImpl>(uint32_t device_id)>; | |||
| class ModelFactory { | |||
| public: | |||
| ModelFactory(const ModelFactory &) = delete; | |||
| void operator=(const ModelFactory &) = delete; | |||
| virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0; | |||
| virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const = 0; | |||
| static ModelFactory &Instance() { | |||
| static ModelFactory instance; | |||
| return instance; | |||
| protected: | |||
| Status Load(const std::shared_ptr<GraphCell> &graph_cell) { | |||
| MS_EXCEPTION_IF_NULL(graph_cell); | |||
| return graph_cell->Load(); | |||
| } | |||
| void Register(const std::string &device_name, ModelCreator &&model_creator) { | |||
| if (model_creators_.find(device_name) == model_creators_.end()) { | |||
| (void)model_creators_.emplace(device_name, model_creator); | |||
| FuncGraphPtr GetFuncGraph() const { | |||
| if (graph_->ModelType() != ModelType::kMindIR) { | |||
| return nullptr; | |||
| } | |||
| } | |||
| std::shared_ptr<ModelImpl> Create(const std::string &device_name, uint32_t device_id) { | |||
| auto iter = model_creators_.find(device_name); | |||
| if (model_creators_.end() != iter) { | |||
| MS_EXCEPTION_IF_NULL(iter->second); | |||
| return (iter->second)(device_id); | |||
| } | |||
| return nullptr; | |||
| auto graph_data = graph_->graph_data_; | |||
| MS_EXCEPTION_IF_NULL(graph_data); | |||
| return graph_data->GetFuncGraph(); | |||
| } | |||
| bool CheckModelSupport(const std::string &device_type, ModelType /*model_type*/) { | |||
| return std::any_of( | |||
| model_creators_.begin(), model_creators_.end(), | |||
| [&device_type](const std::pair<std::string, ModelCreator> &item) { return item.first == device_type; }); | |||
| } | |||
| std::shared_ptr<Graph> graph_; | |||
| private: | |||
| ModelFactory() = default; | |||
| ~ModelFactory() = default; | |||
| std::map<std::string, ModelCreator> model_creators_; | |||
| friend class Model; | |||
| void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; } | |||
| }; | |||
| class ModelRegistrar { | |||
| public: | |||
| ModelRegistrar(const std::string &device_name, ModelCreator model_creator) { | |||
| ModelFactory::Instance().Register(device_name, std::move(model_creator)); | |||
| } | |||
| ~ModelRegistrar() = default; | |||
| }; | |||
| #define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS) \ | |||
| static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg( \ | |||
| kDeviceType##DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); }); | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H | |||
| @@ -16,164 +16,33 @@ | |||
| #include "cxx_api/model/ms/ms_model.h" | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include "load_mindir/load_model.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "backend/session/session_factory.h" | |||
| #include "backend/session/executor_manager.h" | |||
| #include "base/base_ref_utils.h" | |||
| #include "backend/kernel_compiler/oplib/oplib.h" | |||
| #include "utils/context/context_extends.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/embed.h" | |||
| #ifdef ENABLE_D | |||
| #include "utils/ms_context.h" | |||
| #endif | |||
| using std::string; | |||
| using std::vector; | |||
| #include "cxx_api/factory.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace api { | |||
| MsModel::MsModel(uint32_t device_id) : device_id_(device_id) {} | |||
| MsModel::~MsModel() = default; | |||
| API_FACTORY_REG(ModelImpl, Ascend910, MsModel); | |||
| TypeId TransInferDataType2TypeId(DataType data_type) { | |||
| const std::map<api::DataType, TypeId> type2id_map{ | |||
| {api::kMsUnknown, TypeId::kNumberTypeBegin}, {api::kMsBool, TypeId::kNumberTypeBool}, | |||
| {api::kMsInt8, TypeId::kNumberTypeInt8}, {api::kMsUint8, TypeId::kNumberTypeUInt8}, | |||
| {api::kMsInt16, TypeId::kNumberTypeInt16}, {api::kMsUint16, TypeId::kNumberTypeUInt16}, | |||
| {api::kMsInt32, TypeId::kNumberTypeInt32}, {api::kMsUint32, TypeId::kNumberTypeUInt32}, | |||
| {api::kMsInt64, TypeId::kNumberTypeInt64}, {api::kMsUint64, TypeId::kNumberTypeUInt64}, | |||
| {api::kMsFloat16, TypeId::kNumberTypeFloat16}, {api::kMsFloat32, TypeId::kNumberTypeFloat32}, | |||
| {api::kMsFloat64, TypeId::kNumberTypeFloat64}, | |||
| }; | |||
| auto it = type2id_map.find(data_type); | |||
| if (it == type2id_map.end()) { | |||
| MS_LOG_WARNING << "Unsupported MSI data type " << data_type; | |||
| return TypeId::kNumberTypeBegin; | |||
| } else { | |||
| return it->second; | |||
| } | |||
| } | |||
| Status MsModel::Build(const std::map<std::string, std::string> &) { | |||
| MS_LOG(INFO) << "Start build model."; | |||
| MS_EXCEPTION_IF_NULL(graph_); | |||
| DataType TransTypeId2InferDataType(TypeId type_id) { | |||
| const std::map<TypeId, api::DataType> id2type_map{ | |||
| {TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool}, | |||
| {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8}, | |||
| {TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16}, | |||
| {TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32}, | |||
| {TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64}, | |||
| {TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16}, | |||
| {TypeId::kNumberTypeFloat32, api::kMsFloat32}, | |||
| }; | |||
| auto it = id2type_map.find(type_id); | |||
| if (it == id2type_map.end()) { | |||
| MS_LOG_WARNING << "Unsupported data id " << type_id; | |||
| return api::kMsUnknown; | |||
| } else { | |||
| return it->second; | |||
| } | |||
| } | |||
| Buffer MsModel::ReadFile(const std::string &file) { | |||
| if (file.empty()) { | |||
| MS_LOG(ERROR) << "file is nullptr"; | |||
| return Buffer(); | |||
| } | |||
| std::ifstream ifs(file); | |||
| if (!ifs.good()) { | |||
| MS_LOG(ERROR) << "file: " << file << " is not exist"; | |||
| return Buffer(); | |||
| } | |||
| if (!ifs.is_open()) { | |||
| MS_LOG(ERROR) << "file: " << file << "open failed"; | |||
| return Buffer(); | |||
| } | |||
| auto func_graph = ModelImpl::GetFuncGraph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| ifs.seekg(0, std::ios::end); | |||
| size_t size = ifs.tellg(); | |||
| Buffer buffer; | |||
| buffer.ResizeData(size); | |||
| ifs.seekg(0, std::ios::beg); | |||
| ifs.read(static_cast<char *>(buffer.MutableData()), size); | |||
| ifs.close(); | |||
| return buffer; | |||
| } | |||
| Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) { | |||
| auto status = InitEnv({}); | |||
| if (status != SUCCESS) { | |||
| MS_LOG(ERROR) << "Init env failed"; | |||
| return FAILED; | |||
| } | |||
| std::shared_ptr<FuncGraph> anf_graph; | |||
| Py_Initialize(); | |||
| try { | |||
| anf_graph = ConvertStreamToFuncGraph(static_cast<const char *>(model_data.Data()), model_data.DataSize()); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference LoadModel failed"; | |||
| return FAILED; | |||
| } | |||
| Status ret = CompileGraph(anf_graph); | |||
| auto graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(func_graph, ModelType::kMindIR)); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto graph_cell = std::make_shared<GraphCell>(graph); | |||
| MS_EXCEPTION_IF_NULL(graph_cell); | |||
| auto ret = ModelImpl::Load(graph_cell); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Compile graph model failed"; | |||
| return FAILED; | |||
| } | |||
| session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_); | |||
| session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_); | |||
| if (inputs_.empty() || inputs_.size() != input_names_.size()) { | |||
| MS_LOG_ERROR << "Get model inputs info failed"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_.empty() || outputs_.size() != output_names_.size()) { | |||
| MS_LOG_ERROR << "Get model outputs info failed"; | |||
| return FAILED; | |||
| MS_LOG(ERROR) << "Load failed."; | |||
| return ret; | |||
| } | |||
| MS_LOG(INFO) << "Load model success"; | |||
| #ifdef ENABLE_D | |||
| // set d context | |||
| rtError_t rt_ret = rtCtxGetCurrent(&context_); | |||
| if (rt_ret != RT_ERROR_NONE || context_ == nullptr) { | |||
| MS_LOG(ERROR) << "the ascend device context is null"; | |||
| return FAILED; | |||
| } | |||
| #endif | |||
| load_flag_ = true; | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::LoadModel(const std::string &file_name, ModelType type, | |||
| const std::map<std::string, std::string> &options) { | |||
| auto graphBuf = ReadFile(file_name); | |||
| if (graphBuf.DataSize() == 0) { | |||
| MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); | |||
| return FAILED; | |||
| } | |||
| auto status = LoadModel(graphBuf, type, options); | |||
| if (status != SUCCESS) { | |||
| MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::UnloadModel() { | |||
| if (!load_flag_) { | |||
| MS_LOG_ERROR << "Model has not been loaded"; | |||
| return FAILED; | |||
| } | |||
| FinalizeEnv(); | |||
| load_flag_ = false; | |||
| // save result | |||
| graph_cell_ = graph_cell; | |||
| MS_LOG(INFO) << "Build model success."; | |||
| return SUCCESS; | |||
| } | |||
| @@ -187,231 +56,42 @@ Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) { | |||
| return FAILED; | |||
| } | |||
| Status MsModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) { | |||
| Status MsModel::Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (!load_flag_) { | |||
| MS_LOG(ERROR) << "No model is loaded, predict failed."; | |||
| if (graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid data, graph_ is null."; | |||
| return FAILED; | |||
| } | |||
| if (inputs.size() != inputs_.size()) { | |||
| MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| std::vector<Buffer> request; | |||
| std::vector<Buffer> reply; | |||
| for (size_t i = 0; i < inputs_.size(); ++i) { | |||
| const auto &input_name = input_names_[i]; | |||
| auto iter = inputs.find(input_name); | |||
| if (iter == inputs.end()) { | |||
| MS_LOG(ERROR) << "Model missing input " << input_name; | |||
| return INVALID_INPUTS; | |||
| } | |||
| if (iter->second.DataSize() != inputs_[i]->Size()) { | |||
| MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count " | |||
| << iter->second.DataSize(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| request.push_back(iter->second); | |||
| } | |||
| if (ExecuteModel(request, &reply) != SUCCESS) { | |||
| MS_LOG(ERROR) << "Execute Model Failed"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_.size() != reply.size()) { | |||
| MS_LOG(ERROR) << "Predict output size " << reply.size() << " not match output size got from model info " | |||
| << outputs_.size(); | |||
| return FAILED; | |||
| } | |||
| outputs->clear(); | |||
| for (size_t i = 0; i < reply.size(); i++) { | |||
| outputs->emplace(output_names_[i], reply[i]); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) { | |||
| MS_EXCEPTION_IF_NULL(reply); | |||
| #ifdef ENABLE_D | |||
| if (context_ == nullptr) { | |||
| MS_LOG(ERROR) << "rtCtx is nullptr"; | |||
| return FAILED; | |||
| } | |||
| rtError_t rt_ret = rtCtxSetCurrent(context_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "set Ascend rtCtx failed"; | |||
| return FAILED; | |||
| } | |||
| #endif | |||
| vector<tensor::TensorPtr> inputs; | |||
| for (size_t i = 0; i < request.size(); i++) { | |||
| auto &item = request[i]; | |||
| auto input = inputs_[i]; | |||
| if (input->Size() != item.DataSize()) { | |||
| MS_LOG(ERROR) << "Predict input " << i << " data size " << item.DataSize() << " not match model input data size " | |||
| << input->Size(); | |||
| return FAILED; | |||
| } | |||
| auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); | |||
| if (graph_cell_ == nullptr) { | |||
| MS_LOG(INFO) << "Model has not been built, it will be built with default options"; | |||
| Status ret = Build({}); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Tensor copy failed"; | |||
| MS_LOG(ERROR) << "Build model failed."; | |||
| return FAILED; | |||
| } | |||
| inputs.push_back(input); | |||
| } | |||
| vector<tensor::TensorPtr> outputs = RunGraph(inputs); | |||
| if (outputs.empty()) { | |||
| MS_LOG(ERROR) << "Execute Model Failed"; | |||
| return FAILED; | |||
| } | |||
| reply->clear(); | |||
| std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), | |||
| [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); }); | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::FinalizeEnv() { | |||
| MS_LOG_INFO << "Start finalize env"; | |||
| py::gil_scoped_acquire acquire; | |||
| session::ExecutorManager::Instance().Clear(); | |||
| device::KernelRuntimeManager::Instance().ClearRuntimeResource(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(ERROR) << "Get Context failed!"; | |||
| return FAILED; | |||
| } | |||
| if (!context::CloseTsd(ms_context)) { | |||
| MS_LOG(ERROR) << "Inference CloseTsd failed!"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG_INFO << "End finalize env"; | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) { | |||
| Py_Initialize(); | |||
| MS_EXCEPTION_IF_NULL(model_buf); | |||
| try { | |||
| auto anf_graph = ConvertStreamToFuncGraph(model_buf, size); | |||
| return anf_graph; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what(); | |||
| return nullptr; | |||
| } | |||
| } | |||
| void MsModel::RegAllOp() { | |||
| static std::mutex init_mutex; | |||
| static bool Initialized = false; | |||
| std::lock_guard<std::mutex> lock(init_mutex); | |||
| if (Initialized) { | |||
| return; | |||
| } | |||
| Initialized = true; | |||
| auto ms_context_instance = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context_instance); | |||
| ms_context_instance->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| try { | |||
| std::shared_ptr<py::scoped_interpreter> guard; | |||
| if (Py_IsInitialized() == 0) { | |||
| guard = std::make_shared<py::scoped_interpreter>(); | |||
| } | |||
| py::module c_expression = py::module::import("mindspore._c_expression"); | |||
| size_t ops_info_long = c_expression.attr("OpInfoLoaderPy")().attr("get_all_ops_info")().cast<size_t>(); | |||
| auto all_ops_info = reinterpret_cast<std::vector<kernel::OpInfo *> *>(static_cast<uintptr_t>(ops_info_long)); | |||
| for (auto op_info : *all_ops_info) { | |||
| kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info)); | |||
| } | |||
| all_ops_info->clear(); | |||
| delete all_ops_info; | |||
| } catch (const std::runtime_error &ex) { | |||
| MS_LOG_EXCEPTION << ex.what(); | |||
| } | |||
| } | |||
| Status MsModel::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| try { | |||
| graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); | |||
| py::gil_scoped_release gil_release; | |||
| return SUCCESS; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference CompileGraph failed: " << e.what(); | |||
| return FAILED; | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> MsModel::RunGraph(const std::vector<tensor::TensorPtr> &inputs) { | |||
| try { | |||
| VectorRef outputs; | |||
| session_impl_->RunGraph(graph_id_, inputs, &outputs); | |||
| return TransformVectorRefToMultiTensor(outputs); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference Rungraph failed: " << e.what(); | |||
| return std::vector<tensor::TensorPtr>(); | |||
| } | |||
| } | |||
| Status MsModel::InitEnv(const std::unordered_map<std::string, std::string> &other_options) { | |||
| RegAllOp(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(ERROR) << "Get Context failed!"; | |||
| return FAILED; | |||
| } | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); | |||
| ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice); | |||
| if (!context::OpenTsd(ms_context)) { | |||
| MS_LOG(ERROR) << "Session init OpenTsd failed!"; | |||
| return FAILED; | |||
| } | |||
| session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice); | |||
| if (session_impl_ == nullptr) { | |||
| MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice | |||
| << " is available."; | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| Status ret = graph_cell_->Run(inputs, outputs); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Run graph failed."; | |||
| return FAILED; | |||
| } | |||
| session_impl_->Init(device_id_); | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| std::string error_msg; | |||
| if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) { | |||
| return Status(INVALID_INPUTS, error_msg); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| MS_EXCEPTION_IF_NULL(tensor_list); | |||
| tensor_list->clear(); | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto &tensor = inputs_[i]; | |||
| Tensor infer_tensor; | |||
| infer_tensor.SetName(input_names_[i]); | |||
| infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type())); | |||
| infer_tensor.SetShape(tensor->shape()); | |||
| tensor_list->push_back(infer_tensor); | |||
| } | |||
| return SUCCESS; | |||
| Status MsModel::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| return graph_cell_->GetInputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| Status MsModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| MS_EXCEPTION_IF_NULL(tensor_list); | |||
| tensor_list->clear(); | |||
| for (size_t i = 0; i < outputs_.size(); i++) { | |||
| auto &tensor = outputs_[i]; | |||
| Tensor infer_tensor; | |||
| infer_tensor.SetName(output_names_[i]); | |||
| infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type())); | |||
| infer_tensor.SetShape(tensor->shape()); | |||
| tensor_list->push_back(infer_tensor); | |||
| } | |||
| return SUCCESS; | |||
| Status MsModel::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const { | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| return graph_cell_->GetOutputsInfo(names, shapes, data_types, mem_sizes); | |||
| } | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| @@ -36,49 +36,23 @@ namespace mindspore { | |||
| namespace api { | |||
| class MsModel : public ModelImpl { | |||
| public: | |||
| explicit MsModel(uint32_t device_id); | |||
| ~MsModel(); | |||
| MsModel() {} | |||
| ~MsModel() = default; | |||
| Status LoadModel(const Buffer &model_data, ModelType type, | |||
| const std::map<std::string, std::string> &options) override; | |||
| Status LoadModel(const std::string &file_name, ModelType type, | |||
| const std::map<std::string, std::string> &options) override; | |||
| Status UnloadModel() override; | |||
| Status Build(const std::map<std::string, std::string> &options_map) override; | |||
| Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override; | |||
| Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override; | |||
| Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override; | |||
| Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; | |||
| Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override; | |||
| Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override; | |||
| Status InitEnv(const std::unordered_map<std::string, std::string> &other_options); | |||
| Status FinalizeEnv(); | |||
| Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override; | |||
| Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override; | |||
| private: | |||
| std::shared_ptr<session::SessionBasic> session_impl_ = nullptr; | |||
| uint32_t graph_id_; | |||
| std::string device_type_; | |||
| int32_t device_id_ = 0; | |||
| #ifdef ENABLE_D | |||
| rtContext_t context_ = nullptr; | |||
| #endif | |||
| std::vector<tensor::TensorPtr> inputs_; | |||
| std::vector<tensor::TensorPtr> outputs_; | |||
| std::vector<std::string> input_names_; | |||
| std::vector<std::string> output_names_; | |||
| bool load_flag_ = false; | |||
| std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device); | |||
| Buffer ReadFile(const std::string &file); | |||
| static void RegAllOp(); | |||
| Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr); | |||
| Status CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const; | |||
| std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs); | |||
| Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); | |||
| std::shared_ptr<GraphCell> graph_cell_; | |||
| }; | |||
| API_REG_MODEL(AscendMS, MsModel); | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H | |||
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/python_utils.h" | |||
| #include <mutex> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "mindspore/core/utils/ms_context.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "backend/kernel_compiler/oplib/oplib.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore::api { | |||
| void RegAllOpFromPython() { | |||
| static std::mutex init_mutex; | |||
| static bool Initialized = false; | |||
| std::lock_guard<std::mutex> lock(init_mutex); | |||
| if (Initialized) { | |||
| return; | |||
| } | |||
| Initialized = true; | |||
| MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| Py_Initialize(); | |||
| auto c_expression = PyImport_ImportModule("mindspore._c_expression"); | |||
| MS_EXCEPTION_IF_NULL(c_expression); | |||
| PyObject *c_expression_dict = PyModule_GetDict(c_expression); | |||
| MS_EXCEPTION_IF_NULL(c_expression_dict); | |||
| PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); | |||
| MS_EXCEPTION_IF_NULL(op_info_loader_class); | |||
| PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); | |||
| MS_EXCEPTION_IF_NULL(op_info_loader); | |||
| PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); | |||
| MS_EXCEPTION_IF_NULL(op_info_loader_ins); | |||
| auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); | |||
| MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul); | |||
| auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); | |||
| auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr); | |||
| for (auto op_info : *all_ops_info) { | |||
| kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info)); | |||
| } | |||
| all_ops_info->clear(); | |||
| delete all_ops_info; | |||
| Py_DECREF(op_info_loader); | |||
| Py_DECREF(op_info_loader_class); | |||
| Py_DECREF(c_expression_dict); | |||
| Py_DECREF(c_expression); | |||
| } | |||
| bool PythonIsInited() { return Py_IsInitialized() != 0; } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H | |||
| #define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H | |||
| #include "pybind11/pybind11.h" | |||
| namespace mindspore::api { | |||
| void RegAllOpFromPython(); | |||
| bool PythonIsInited(); | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H | |||
| @@ -14,9 +14,77 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/serialization.h" | |||
| #include <fstream> | |||
| #include "cxx_api/graph/graph_data.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "mindspore/core/load_mindir/load_model.h" | |||
| namespace mindspore::api { | |||
| static Buffer ReadFile(const std::string &file) { | |||
| Buffer buffer; | |||
| if (file.empty()) { | |||
| MS_LOG(ERROR) << "Pointer file is nullptr"; | |||
| return buffer; | |||
| } | |||
| char real_path_mem[PATH_MAX] = {0}; | |||
| char *real_path_ret = nullptr; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX); | |||
| #else | |||
| real_path_ret = realpath(common::SafeCStr(file), real_path_mem); | |||
| #endif | |||
| if (real_path_ret == nullptr) { | |||
| MS_LOG(ERROR) << "File: " << file << " is not exist."; | |||
| return buffer; | |||
| } | |||
| std::string real_path(real_path_mem); | |||
| std::ifstream ifs(real_path); | |||
| if (!ifs.good()) { | |||
| MS_LOG(ERROR) << "File: " << real_path << " is not exist"; | |||
| return buffer; | |||
| } | |||
| if (!ifs.is_open()) { | |||
| MS_LOG(ERROR) << "File: " << real_path << "open failed"; | |||
| return buffer; | |||
| } | |||
| ifs.seekg(0, std::ios::end); | |||
| size_t size = ifs.tellg(); | |||
| buffer.ResizeData(size); | |||
| if (buffer.DataSize() != size) { | |||
| MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path; | |||
| ifs.close(); | |||
| return buffer; | |||
| } | |||
| ifs.seekg(0, std::ios::beg); | |||
| ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size); | |||
| ifs.close(); | |||
| return buffer; | |||
| } | |||
| Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { | |||
| Buffer data = ReadFile(file); | |||
| if (model_type == kMindIR) { | |||
| FuncGraphPtr anf_graph = nullptr; | |||
| try { | |||
| anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize()); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Load MindIR failed."; | |||
| } | |||
| return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR)); | |||
| } else if (model_type == kOM) { | |||
| return Graph(std::make_shared<Graph::GraphData>(data, kOM)); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type; | |||
| } | |||
| Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) { | |||
| MS_LOG(ERROR) << "Unsupported feature."; | |||
| return FAILED; | |||
| @@ -19,6 +19,9 @@ | |||
| #include "utils/utils.h" | |||
| namespace mindspore::api { | |||
| const char *kDeviceTypeAscend310 = "Ascend310"; | |||
| const char *kDeviceTypeAscend910 = "Ascend910"; | |||
| class DataImpl { | |||
| public: | |||
| DataImpl() : data_() {} | |||
| @@ -422,7 +422,6 @@ inline ValuePtr MakeValue(S v) { | |||
| template <typename S, typename U = typename ImmTraits<S>::type> | |||
| static S GetValue(const ValuePtr &value) { | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| U imm = value->cast<U>(); | |||
| if (imm == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name(); | |||
| @@ -1,5 +1,10 @@ | |||
| #add flags | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare") | |||
| add_subdirectory("ut") | |||
| if (ENABLE_ACL) | |||
| add_subdirectory(cxx_st) | |||
| elseif (ENABLE_GPU OR ENABLE_D OR ENABLE_CPU) | |||
| message(fatal "No need set -e xxx when compile ut") | |||
| else () | |||
| add_subdirectory(ut) | |||
| endif() | |||
| @@ -0,0 +1,11 @@ | |||
| include_directories(${PYTHON_INCLUDE_DIRS}) | |||
| include_directories(${MS_CCSRC_PATH}) | |||
| include_directories(${CMAKE_SOURCE_DIR}/mindspore/core) | |||
| include_directories(${CMAKE_CURRENT_SOURCE_DIR}) | |||
| include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/) | |||
| include_directories(${CMAKE_BINARY_DIR}) | |||
| include_directories(${CUDA_INCLUDE_DIRS}) | |||
| file(GLOB_RECURSE CXX_ST_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cc) | |||
| add_executable(st_tests ${CXX_ST_SRC}) | |||
| target_link_libraries(st_tests PRIVATE mindspore_shared_lib mindspore::gtest) | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common_test.h" | |||
| #ifdef __cplusplus | |||
| #if __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| #endif | |||
| namespace ST { | |||
| void Common::SetUpTestCase() {} | |||
| void Common::TearDownTestCase() {} | |||
| void Common::SetUp() {} | |||
| void Common::TearDown() {} | |||
| } // namespace ST | |||
| #ifdef __cplusplus | |||
| #if __cplusplus | |||
| } | |||
| #endif | |||
| #endif | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef TESTS_CXX_ST_COMMON_COMMON_TEST_H_ | |||
| #define TESTS_CXX_ST_COMMON_COMMON_TEST_H_ | |||
| #include <cmath> | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include "gtest/gtest.h" | |||
| namespace ST { | |||
| class Common : public testing::Test { | |||
| public: | |||
| // TestCase only enter once | |||
| static void SetUpTestCase(); | |||
| static void TearDownTestCase(); | |||
| // every TEST_F macro will enter one | |||
| virtual void SetUp(); | |||
| virtual void TearDown(); | |||
| template <typename T> | |||
| void PrintData(std::string name, T *output_data, int size) { | |||
| std::cout << "The " << name << " is as follows:" << std::endl; | |||
| if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { | |||
| for (size_t i = 0; i < std::min(size, 100); i++) { | |||
| std::cout << (int)output_data[i] << " "; | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < std::min(size, 100); i++) { | |||
| std::cout << output_data[i] << " "; | |||
| } | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| template <typename T> | |||
| static void CompareOutputData(T *output_data, T *correct_data, int size, float err_bound) { | |||
| for (size_t i = 0; i < size; i++) { | |||
| T abs = fabs(output_data[i] - correct_data[i]); | |||
| ASSERT_LE(abs, err_bound); | |||
| } | |||
| } | |||
| void ReadFile(const char *file, size_t *size, char **buf) { | |||
| ASSERT_NE(nullptr, file); | |||
| ASSERT_NE(nullptr, size); | |||
| ASSERT_NE(nullptr, buf); | |||
| std::string path = std::string(file); | |||
| std::ifstream ifs(path); | |||
| ASSERT_EQ(true, ifs.good()); | |||
| ASSERT_EQ(true, ifs.is_open()); | |||
| ifs.seekg(0, std::ios::end); | |||
| *size = ifs.tellg(); | |||
| *buf = new char[*size]; | |||
| ifs.seekg(0, std::ios::beg); | |||
| ifs.read(*buf, *size); | |||
| ifs.close(); | |||
| } | |||
| }; | |||
| } // namespace ST | |||
| #endif // TESTS_CXX_ST_COMMON_COMMON_TEST_H_ | |||
| @@ -0,0 +1,22 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "gtest/gtest.h" | |||
| GTEST_API_ int main(int argc, char** argv) { | |||
| testing::InitGoogleTest(&argc, argv); | |||
| int ret = RUN_ALL_TESTS(); | |||
| return ret; | |||
| } | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "common/common_test.h" | |||
| #include "include/api/model.h" | |||
| #include "include/api/serialization.h" | |||
| #include "include/api/context.h" | |||
| using namespace mindspore::api; | |||
| static const char tensor_add_file[] = "/home/workspace/mindspore_dataset/tensor_add/tensor_add.mindir"; | |||
| static const std::vector<float> input_data_1 = {1, 2, 3, 4}; | |||
| static const std::vector<float> input_data_2 = {2, 3, 4, 5}; | |||
| class TestTensorAdd : public ST::Common { | |||
| public: | |||
| TestTensorAdd() {} | |||
| }; | |||
| TEST_F(TestTensorAdd, InferMindIR) { | |||
| Context::Instance().SetDeviceTarget(kDeviceTypeAscend310).SetDeviceID(1); | |||
| auto graph = Serialization::LoadModel(tensor_add_file, ModelType::kMindIR); | |||
| Model tensor_add((GraphCell(graph))); | |||
| Status ret = tensor_add.Build({}); | |||
| ASSERT_TRUE(ret == SUCCESS); | |||
| // prepare input | |||
| std::vector<Buffer> outputs; | |||
| std::vector<Buffer> inputs; | |||
| inputs.emplace_back(Buffer(input_data_1.data(), sizeof(float) * input_data_1.size())); | |||
| inputs.emplace_back(Buffer(input_data_2.data(), sizeof(float) * input_data_2.size())); | |||
| // infer | |||
| ret = tensor_add.Predict(inputs, &outputs); | |||
| ASSERT_TRUE(ret == SUCCESS); | |||
| for (auto &buffer : outputs) { | |||
| const float *p = reinterpret_cast<const float *>(buffer.Data()); | |||
| for (size_t i = 0; i < buffer.DataSize() / sizeof(float); ++i) { | |||
| ASSERT_LE(std::abs(p[i] - (input_data_1[i] + input_data_2[i])), 1e-4); | |||
| } | |||
| } | |||
| } | |||