1. namespace is mindspore::api; 2. API header files located in mindspore/core/mindapi; 3. We use pimpl pattern to provide a wrapper layer for api; 4. Check mindapi_test.cc for usage examples.tags/v1.6.0
| @@ -221,6 +221,8 @@ if(PLATFORM_ARM64) | |||
| endif() | |||
| install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE) | |||
| install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api | |||
| @@ -298,6 +300,8 @@ elseif(PLATFORM_ARM32) | |||
| endif() | |||
| install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE) | |||
| install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api | |||
| @@ -365,6 +369,8 @@ elseif(WIN32) | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE) | |||
| install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api | |||
| @@ -409,6 +415,8 @@ else() | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE) | |||
| install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api | |||
| @@ -430,6 +438,10 @@ else() | |||
| PATTERN "train*" EXCLUDE PATTERN "delegate.h" EXCLUDE PATTERN "lite_session.h" EXCLUDE) | |||
| install(FILES ${API_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/api | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${MINDAPI_BASE_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/mindapi/base | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${MINDAPI_IR_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/mindapi/ir | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${ABSTRACT_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract | |||
| COMPONENT ${RUNTIME_COMPONENT_NAME}) | |||
| install(FILES ${API_IR_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/api/ir | |||
| @@ -23,6 +23,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "ir/*.cc" | |||
| "utils/*.cc" | |||
| "load_mindir/*.cc" | |||
| "mindapi/src/*.cc" | |||
| ) | |||
| if(ENABLE_SECURITY) | |||
| @@ -24,8 +24,7 @@ | |||
| #include "utils/visible.h" | |||
| #include "api/ir/func_graph_manager.h" | |||
| namespace mindspore::api { | |||
| namespace mindspore::deprecated::api { | |||
| /// \brief FuncGraph defines interface for a function graph. | |||
| class MS_CORE_API FuncGraph { | |||
| public: | |||
| @@ -147,5 +146,12 @@ class MS_CORE_API FuncGraph { | |||
| /// \return The function graph if the input is value node that holds the graph, nullptr otherwise. | |||
| static FuncGraphPtr GetFuncGraphFromAnfNode(const AnfNodePtr &input); | |||
| }; | |||
| } // namespace mindspore::api | |||
| #ifndef USE_DEPRECATED_API | |||
| #define USE_DEPRECATED_API | |||
| namespace mindspore { | |||
| namespace api = deprecated::api; | |||
| } | |||
| #endif | |||
| } // namespace mindspore::deprecated::api | |||
| #endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_H_ | |||
| @@ -26,8 +26,7 @@ | |||
| #include "utils/hashing.h" | |||
| #include "ir/anf.h" | |||
| namespace mindspore::api { | |||
| namespace mindspore::deprecated::api { | |||
| class FuncGraph; | |||
| using FuncGraphPtr = std::shared_ptr<FuncGraph>; | |||
| @@ -80,7 +79,13 @@ class MS_CORE_API FuncGraphManager { | |||
| /// \return The manager that manages the given function graph. | |||
| static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true); | |||
| }; | |||
| } // namespace mindspore::api | |||
| } // namespace mindspore::deprecated::api | |||
| #ifndef USE_DEPRECATED_API | |||
| #define USE_DEPRECATED_API | |||
| namespace mindspore { | |||
| namespace api = deprecated::api; | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_MANAGER_H_ | |||
| @@ -52,6 +52,7 @@ static const std::vector<std::string> sub_module_names = { | |||
| "HCCL_ADPT", // SM_HCCL_ADPT | |||
| "RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK | |||
| "GE", // SM_GE | |||
| "API", // SM_API | |||
| }; | |||
| const std::string GetSubModuleName(SubModuleId module_id) { return sub_module_names[(module_id % NUM_SUBMODUES)]; } | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -19,87 +19,6 @@ | |||
| #ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_ | |||
| #define MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_ | |||
| namespace mindspore { | |||
| // | |||
| // Supported meta type | |||
| // | |||
| enum TypeId : int { | |||
| kTypeUnknown = 0, | |||
| kMetaTypeBegin = kTypeUnknown, | |||
| kMetaTypeType, // Type | |||
| kMetaTypeAnything, | |||
| kMetaTypeObject, | |||
| kMetaTypeTypeType, // TypeType | |||
| kMetaTypeProblem, | |||
| kMetaTypeExternal, | |||
| kMetaTypeNone, | |||
| kMetaTypeNull, | |||
| kMetaTypeEllipsis, | |||
| kMetaTypeEnd, | |||
| // | |||
| // Object types | |||
| // | |||
| kObjectTypeBegin = kMetaTypeEnd, | |||
| kObjectTypeNumber, | |||
| kObjectTypeString, | |||
| kObjectTypeList, | |||
| kObjectTypeTuple, | |||
| kObjectTypeSlice, | |||
| kObjectTypeKeyword, | |||
| kObjectTypeTensorType, | |||
| kObjectTypeRowTensorType, | |||
| kObjectTypeSparseTensorType, | |||
| kObjectTypeUndeterminedType, | |||
| kObjectTypeClass, | |||
| kObjectTypeDictionary, | |||
| kObjectTypeFunction, | |||
| kObjectTypeJTagged, | |||
| kObjectTypeSymbolicKeyType, | |||
| kObjectTypeEnvType, | |||
| kObjectTypeRefKey, | |||
| kObjectTypeRef, | |||
| kObjectTypeEnd, | |||
| // | |||
| // Number Types | |||
| // | |||
| kNumberTypeBegin = kObjectTypeEnd, | |||
| kNumberTypeBool, | |||
| kNumberTypeInt, | |||
| kNumberTypeInt8, | |||
| kNumberTypeInt16, | |||
| kNumberTypeInt32, | |||
| kNumberTypeInt64, | |||
| kNumberTypeUInt, | |||
| kNumberTypeUInt8, | |||
| kNumberTypeUInt16, | |||
| kNumberTypeUInt32, | |||
| kNumberTypeUInt64, | |||
| kNumberTypeFloat, | |||
| kNumberTypeFloat16, | |||
| kNumberTypeFloat32, | |||
| kNumberTypeFloat64, | |||
| kNumberTypeComplex, | |||
| kNumberTypeComplex64, | |||
| kNumberTypeComplex128, | |||
| kNumberTypeInt4, | |||
| kNumberTypeGLUInt, | |||
| kNumberTypeEnd, | |||
| // | |||
| // Monad Types | |||
| // | |||
| kMonadTypeBegin = kNumberTypeEnd, | |||
| kObjectTypeMonad, | |||
| kObjectTypeUMonad, | |||
| kObjectTypeIOMonad, | |||
| kMonadTypeEnd, | |||
| // | |||
| // Sparse Types | |||
| // | |||
| // Sparse types is placed at the end of enum, | |||
| // in order to keep fit with the type of existing model on the lite side. | |||
| kSparseTypeBegin = kMonadTypeEnd, | |||
| kObjectTypeCSRTensorType, | |||
| kSparseTypeEnd | |||
| }; | |||
| } // namespace mindspore | |||
| #include "mindapi/base/type_id.h" | |||
| #endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_ | |||
| @@ -153,7 +153,7 @@ class FuncGraphBase : public Value { | |||
| MS_DECLARE_PARENT(FuncGraphBase, Value); | |||
| }; | |||
| class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfoHolder { | |||
| class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, public EffectInfoHolder { | |||
| public: | |||
| using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>; | |||
| @@ -265,7 +265,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo | |||
| FuncGraphManagerPtr manager() const { return manager_.lock(); } | |||
| void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); } | |||
| api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); } | |||
| deprecated::api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); } | |||
| std::string ToString() const override; | |||
| GraphDebugInfoPtr debug_info(); | |||
| @@ -55,9 +55,9 @@ class FuncGraphTransaction; | |||
| class FuncGraphManager; | |||
| using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>; | |||
| using AnfNodeIndexSet = api::AnfNodeIndexSet; | |||
| using AnfNodeIndexSet = deprecated::api::AnfNodeIndexSet; | |||
| // NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i) | |||
| using NodeUsersMap = api::NodeUsersMap; | |||
| using NodeUsersMap = deprecated::api::NodeUsersMap; | |||
| using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>; | |||
| using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>; | |||
| @@ -277,7 +277,8 @@ class FuncGraphJTotalComputer final : public DepComputer { | |||
| bool SeekJ(const FuncGraphPtr &fg, size_t seen_num); | |||
| }; | |||
| class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, public api::FuncGraphManager { | |||
| class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, | |||
| public deprecated::api::FuncGraphManager { | |||
| public: | |||
| explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true); | |||
| ~FuncGraphManager() { | |||
| @@ -0,0 +1,89 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_BASE_BASE_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_BASE_BASE_H_ | |||
| #include <cstdint> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "mindapi/base/macros.h" | |||
| #include "mindapi/base/type_traits.h" | |||
| #include "mindapi/base/shared_ptr.h" | |||
| namespace mindspore { | |||
| class Base; | |||
| } | |||
| namespace mindspore::api { | |||
| /// \brief Base is the base class of many api classes, which provides basic interfaces. | |||
| class MIND_API Base { | |||
| public: | |||
| /// \brief Create an instance from the given implementation object. | |||
| /// | |||
| /// \param[in] impl The shared_ptr to the implementation object. | |||
| explicit Base(const std::shared_ptr<mindspore::Base> &impl); | |||
| /// \brief Destructor of Base. | |||
| virtual ~Base() = default; | |||
| /// \brief Get the id of this class. | |||
| /// | |||
| /// \return The id of this class. | |||
| static uint32_t ClassId(); | |||
| /// \brief Get the shared_ptr to the underly implementation object. | |||
| /// | |||
| /// \return The shared_ptr to the underly implementation object. | |||
| const std::shared_ptr<mindspore::Base> &impl() const { return impl_; } | |||
| /// \brief Get the string representation of this object. | |||
| /// | |||
| /// \return The string representation. | |||
| std::string ToString() const; | |||
| /// \brief Check whether this object is an instance of the given class. | |||
| /// | |||
| /// \return True if this object is an instance of the given class, false otherwise. | |||
| template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>, T>> | |||
| inline bool isa() const { | |||
| return IsFromClassId(T::ClassId()); | |||
| } | |||
| /// \brief Cast this object to a pointer with the given pointer class. | |||
| /// | |||
| /// \return A non-null pointer if cast success, nullptr otherwise. | |||
| template <typename T, typename U = typename std::enable_if_t<is_wrapper_ptr<T>::value, typename T::element_type>> | |||
| inline T cast() { | |||
| if (isa<U>()) { | |||
| return MakeShared<U>(impl_); | |||
| } | |||
| return nullptr; | |||
| } | |||
| protected: | |||
| bool IsFromClassId(uint32_t class_id) const; | |||
| const std::shared_ptr<mindspore::Base> impl_; | |||
| }; | |||
| #define MIND_API_BASE_MEMBER(current_class) \ | |||
| explicit current_class(const std::shared_ptr<mindspore::Base> &impl); \ | |||
| ~current_class() override = default; \ | |||
| static uint32_t ClassId() | |||
| using BasePtr = SharedPtr<Base>; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_BASE_BASE_H_ | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_ | |||
| #include <cstdint> | |||
| #include <memory> | |||
| #include <sstream> | |||
| #include <utility> | |||
| #include "mindapi/base/macros.h" | |||
| namespace mindspore::api { | |||
| enum class LogLevel : uint8_t { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION }; | |||
| class LogWriterImpl; | |||
| /// \brief LogStream represents a stream to write log messages. | |||
| /// This class is not expected for directly use, use MS_LOG instead. | |||
| class LogStream { | |||
| public: | |||
| /// \brief Write log message to this LogStream. | |||
| /// | |||
| /// \param[in] value The object to be written. | |||
| template <typename T> | |||
| LogStream &operator<<(T &&value) noexcept { | |||
| (void)stream_.operator<<(std::forward<T>(value)); | |||
| return *this; | |||
| } | |||
| private: | |||
| friend class LogWriterImpl; | |||
| std::stringstream stream_; | |||
| }; | |||
| /// \brief LogWriter defines interface for log message output. | |||
| /// This class is not expected for directly use, use MS_LOG instead. | |||
| class MIND_API LogWriter { | |||
| public: | |||
| /// \brief Create a LogWriter with the given log level, file name, line number and function name. | |||
| /// | |||
| /// \param[in] level The log level. | |||
| /// \param[in] file The file name. | |||
| /// \param[in] line The line number. | |||
| /// \param[in] func The function name. | |||
| LogWriter(LogLevel level, const char *file, int line, const char *func); | |||
| /// \brief Destructor for LogWriter. | |||
| ~LogWriter(); | |||
| /// \brief Output log message from the input log stream. | |||
| /// | |||
| /// \param[in] stream The input log stream. | |||
| void operator<(const LogStream &stream) const noexcept; | |||
| /// \brief Output log message from the input log stream and then throw exception. | |||
| /// | |||
| /// \param[in] stream The input log stream. | |||
| void operator^(const LogStream &stream) const __attribute__((noreturn)); | |||
| /// \brief Check whether the given log level is enabled or not. | |||
| /// | |||
| /// \return True if the log level is enabled, false otherwise. | |||
| static bool IsEnabled(LogLevel level); | |||
| private: | |||
| std::unique_ptr<LogWriterImpl> impl_; | |||
| }; | |||
| #define MIND_LOG_STREAM mindspore::api::LogStream() | |||
| #define MIND_LOG_WRITER mindspore::api::LogWriter | |||
| #define MIND_LOG_LEVEL(L) mindspore::api::LogLevel::L | |||
| #define MIND_LOG_THROW(L) MIND_LOG_WRITER(MIND_LOG_LEVEL(L), __FILE__, __LINE__, __FUNCTION__) ^ MIND_LOG_STREAM | |||
| #define MIND_LOG_WRITE(L) MIND_LOG_WRITER(MIND_LOG_LEVEL(L), __FILE__, __LINE__, __FUNCTION__) < MIND_LOG_STREAM | |||
| #define MIND_LOG_IF(L) \ | |||
| if (MIND_LOG_WRITER::IsEnabled(MIND_LOG_LEVEL(L))) MIND_LOG_WRITE(L) | |||
| #define MIND_LOG_DEBUG MIND_LOG_IF(DEBUG) | |||
| #define MIND_LOG_INFO MIND_LOG_IF(INFO) | |||
| #define MIND_LOG_WARNING MIND_LOG_IF(WARNING) | |||
| #define MIND_LOG_ERROR MIND_LOG_IF(ERROR) | |||
| #define MIND_LOG_EXCEPTION MIND_LOG_THROW(EXCEPTION) | |||
| #define MIND_LOG(level) MIND_LOG_##level | |||
| #if !defined(MIND_LOG_NO_MS_LOG) && !defined(MS_LOG) | |||
| #define MS_LOG(level) MIND_LOG_##level | |||
| #endif | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_ | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_ | |||
| #if (defined(_WIN32) || defined(__WIN32__) || defined(WIN32) || defined(__CYGWIN__)) | |||
| #ifdef BUILDING_DLL | |||
| #define MIND_API __declspec(dllexport) | |||
| #else | |||
| #define MIND_API __declspec(dllimport) | |||
| #endif | |||
| #else | |||
| #define MIND_API __attribute__((visibility("default"))) | |||
| #endif | |||
| #endif // MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_ | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_ | |||
| #include <cstdint> | |||
| #include <vector> | |||
| using ShapeVector = std::vector<int64_t>; | |||
| #endif // MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_ | |||
| @@ -0,0 +1,180 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_ | |||
| #include <cstdint> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <ostream> | |||
| #include <functional> | |||
| namespace mindspore::api { | |||
| /// \brief SharedPtr wraps a std::shared_ptr and provides wrapper functions according the underlying implementation. | |||
| template <typename T> | |||
| class SharedPtr { | |||
| public: | |||
| using element_type = T; | |||
| constexpr SharedPtr() noexcept = default; | |||
| constexpr SharedPtr(std::nullptr_t) noexcept : SharedPtr() {} // NOLINT | |||
| template <typename U> | |||
| explicit SharedPtr(std::shared_ptr<U> &&ptr) : ptr_(std::move(ptr)) {} | |||
| template <typename U> | |||
| SharedPtr(const SharedPtr<U> &other) : ptr_(other.ptr_) {} | |||
| template <typename U> | |||
| SharedPtr(SharedPtr<U> &&other) : ptr_(std::move(other.ptr_)) {} | |||
| template <typename U> | |||
| SharedPtr &operator=(const SharedPtr<U> &other) { | |||
| ptr_ = other.ptr_; | |||
| return *this; | |||
| } | |||
| template <typename U> | |||
| SharedPtr &operator=(SharedPtr<U> &&other) { | |||
| ptr_ = std::move(other.ptr_); | |||
| return *this; | |||
| } | |||
| ~SharedPtr() = default; | |||
| std::uintptr_t addr() const { return (ptr_ == nullptr) ? 0 : reinterpret_cast<std::uintptr_t>(ptr_->impl().get()); } | |||
| element_type &operator*() const noexcept { return *ptr_; } | |||
| element_type *operator->() const noexcept { return ptr_.get(); } | |||
| element_type *get() const noexcept { return ptr_.get(); } | |||
| explicit operator bool() const { return addr() != 0; } | |||
| private: | |||
| template <typename U> | |||
| friend class SharedPtr; | |||
| std::shared_ptr<element_type> ptr_; | |||
| }; | |||
| template <typename T, typename U> | |||
| inline bool operator==(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept { | |||
| return a.addr() == b.addr(); | |||
| } | |||
| template <typename T> | |||
| inline bool operator==(const SharedPtr<T> &a, std::nullptr_t) noexcept { | |||
| return a.addr() == 0; | |||
| } | |||
| template <typename T> | |||
| inline bool operator==(std::nullptr_t, const SharedPtr<T> &a) noexcept { | |||
| return a.addr() == 0; | |||
| } | |||
| template <typename T, typename U> | |||
| inline bool operator!=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept { | |||
| return a.addr() != b.addr(); | |||
| } | |||
| template <typename T> | |||
| inline bool operator!=(const SharedPtr<T> &a, std::nullptr_t) noexcept { | |||
| return a.addr() != 0; | |||
| } | |||
| template <typename T> | |||
| inline bool operator!=(std::nullptr_t, const SharedPtr<T> &a) noexcept { | |||
| return a.addr() != 0; | |||
| } | |||
| template <typename T, typename U> | |||
| inline bool operator<(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept { | |||
| return a.addr() < b.addr(); | |||
| } | |||
| template <typename T> | |||
| inline bool operator<(const SharedPtr<T> &a, std::nullptr_t) noexcept { | |||
| return a.addr() < 0; | |||
| } | |||
| template <typename T> | |||
| inline bool operator<(std::nullptr_t, const SharedPtr<T> &a) noexcept { | |||
| // 'nullptr < ptr' is false only when ptr is nullptr. | |||
| return a.addr() != 0; | |||
| } | |||
| template <typename T, typename U> | |||
| inline bool operator>(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept { | |||
| return a.addr() > b.addr(); | |||
| } | |||
| template <typename T> | |||
| inline bool operator>(const SharedPtr<T> &a, std::nullptr_t) noexcept { | |||
| return a.addr() > 0; | |||
| } | |||
| template <typename T> | |||
| inline bool operator>(std::nullptr_t, const SharedPtr<T> &a) noexcept { | |||
| // 'nullptr > ptr' is always false. | |||
| return false; | |||
| } | |||
| template <typename T, typename U> | |||
| inline bool operator<=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept { | |||
| return a.addr() <= b.addr(); | |||
| } | |||
| template <typename T> | |||
| inline bool operator<=(const SharedPtr<T> &a, std::nullptr_t) noexcept { | |||
| return a.addr() <= 0; | |||
| } | |||
| template <typename T> | |||
| inline bool operator<=(std::nullptr_t, const SharedPtr<T> &a) noexcept { | |||
| // 'nullptr <= ptr' is always true. | |||
| return true; | |||
| } | |||
| template <typename T, typename U> | |||
| inline bool operator>=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept { | |||
| return a.addr() >= b.addr(); | |||
| } | |||
| template <typename T> | |||
| inline bool operator>=(const SharedPtr<T> &a, std::nullptr_t) noexcept { | |||
| return a.addr() >= 0; | |||
| } | |||
| template <typename T> | |||
| inline bool operator>=(std::nullptr_t, const SharedPtr<T> &a) noexcept { | |||
| // 'nullptr >= ptr' is true only when ptr is nullptr. | |||
| return a.addr() == 0; | |||
| } | |||
| template <typename T, typename U, typename V> | |||
| inline std::basic_ostream<U, V> &operator<<(std::basic_ostream<U, V> &os, const SharedPtr<T> &a) { | |||
| return (os << reinterpret_cast<void *>(a.addr())); | |||
| } | |||
| /// \brief Constructs an object of type T and wraps it in a SharedPtr. | |||
| /// | |||
| /// \param[in] args The parameter list for the constructor of T. | |||
| template <typename T, typename... Args> | |||
| inline SharedPtr<T> MakeShared(Args &&... args) { | |||
| auto ptr = std::make_shared<T>(std::forward<Args>(args)...); | |||
| return SharedPtr<T>(std::move(ptr)); | |||
| } | |||
| } // namespace mindspore::api | |||
| namespace std { | |||
| template <typename T> | |||
| struct hash<mindspore::api::SharedPtr<T>> { | |||
| size_t operator()(const mindspore::api::SharedPtr<T> &ptr) const noexcept { return static_cast<size_t>(ptr.addr()); } | |||
| }; | |||
| } // namespace std | |||
| #endif // MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_ | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_ | |||
| namespace mindspore { | |||
| /// \brief TypeId defines data type identifiers. | |||
| enum TypeId : int { | |||
| kTypeUnknown = 0, | |||
| // | |||
| // Meta types. | |||
| // | |||
| kMetaTypeBegin = kTypeUnknown, | |||
| kMetaTypeType, // Type | |||
| kMetaTypeAnything, | |||
| kMetaTypeObject, | |||
| kMetaTypeTypeType, // TypeType | |||
| kMetaTypeProblem, | |||
| kMetaTypeExternal, | |||
| kMetaTypeNone, | |||
| kMetaTypeNull, | |||
| kMetaTypeEllipsis, | |||
| kMetaTypeEnd, | |||
| // | |||
| // Object types | |||
| // | |||
| kObjectTypeBegin = kMetaTypeEnd, | |||
| kObjectTypeNumber, | |||
| kObjectTypeString, | |||
| kObjectTypeList, | |||
| kObjectTypeTuple, | |||
| kObjectTypeSlice, | |||
| kObjectTypeKeyword, | |||
| kObjectTypeTensorType, | |||
| kObjectTypeRowTensorType, | |||
| kObjectTypeSparseTensorType, | |||
| kObjectTypeUndeterminedType, | |||
| kObjectTypeClass, | |||
| kObjectTypeDictionary, | |||
| kObjectTypeFunction, | |||
| kObjectTypeJTagged, | |||
| kObjectTypeSymbolicKeyType, | |||
| kObjectTypeEnvType, | |||
| kObjectTypeRefKey, | |||
| kObjectTypeRef, | |||
| kObjectTypeEnd, | |||
| // | |||
| // Number Types | |||
| // | |||
| kNumberTypeBegin = kObjectTypeEnd, | |||
| kNumberTypeBool, | |||
| kNumberTypeInt, | |||
| kNumberTypeInt8, | |||
| kNumberTypeInt16, | |||
| kNumberTypeInt32, | |||
| kNumberTypeInt64, | |||
| kNumberTypeUInt, | |||
| kNumberTypeUInt8, | |||
| kNumberTypeUInt16, | |||
| kNumberTypeUInt32, | |||
| kNumberTypeUInt64, | |||
| kNumberTypeFloat, | |||
| kNumberTypeFloat16, | |||
| kNumberTypeFloat32, | |||
| kNumberTypeFloat64, | |||
| kNumberTypeComplex, | |||
| kNumberTypeComplex64, | |||
| kNumberTypeComplex128, | |||
| kNumberTypeInt4, | |||
| kNumberTypeGLUInt, | |||
| kNumberTypeEnd, | |||
| // | |||
| // Monad Types | |||
| // | |||
| kMonadTypeBegin = kNumberTypeEnd, | |||
| kObjectTypeMonad, | |||
| kObjectTypeUMonad, | |||
| kObjectTypeIOMonad, | |||
| kMonadTypeEnd, | |||
| // | |||
| // Sparse Types | |||
| // | |||
| // Sparse types is placed at the end of enum, | |||
| // in order to keep fit with the type of existing model on the lite side. | |||
| kSparseTypeBegin = kMonadTypeEnd, | |||
| kObjectTypeCSRTensorType, | |||
| kSparseTypeEnd | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_ | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <type_traits> | |||
| #include "mindapi/base/shared_ptr.h" | |||
| namespace mindspore::api { | |||
| template <typename T> | |||
| struct is_wrapper_ptr : public std::false_type {}; | |||
| template <typename T> | |||
| struct is_wrapper_ptr<SharedPtr<T>> : public std::true_type {}; | |||
| template <typename T> | |||
| struct is_shared_ptr : public std::false_type {}; | |||
| template <typename T> | |||
| struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {}; | |||
| template <typename T> | |||
| struct is_vector : public std::false_type {}; | |||
| template <typename T, typename A> | |||
| struct is_vector<std::vector<T, A>> : public std::true_type {}; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_ | |||
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_ | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/ir/common.h" | |||
| #include "mindapi/ir/shape.h" | |||
| #include "mindapi/ir/type.h" | |||
| #include "mindapi/ir/value.h" | |||
| namespace mindspore::api { | |||
| /// \brief AbstractBase defines base interfaces for abstract of an anf node. | |||
| class MIND_API AbstractBase : public Base { | |||
| public: | |||
| MIND_API_BASE_MEMBER(AbstractBase); | |||
| /// \brief Clone an abstract from this abstract. | |||
| /// | |||
| /// \return A pointer to the cloned abstract. | |||
| AbstractBasePtr Clone() const; | |||
| /// \brief Get the abstract type. | |||
| /// | |||
| /// \return A pointer to the Type. | |||
| TypePtr type() const; | |||
| /// \brief Get the abstract value. | |||
| /// | |||
| /// \return A pointer to the Value. | |||
| ValuePtr value() const; | |||
| /// \brief Set the type for this abstract. | |||
| /// | |||
| /// \param[in] type The type to be set. | |||
| void set_type(const TypePtr &type); | |||
| /// \brief Set the value for this abstract. | |||
| /// | |||
| /// \param[in] value The value to be set. | |||
| void set_value(const ValuePtr &value); | |||
| }; | |||
| /// \brief AbstractTensor describes a tensor's type, shape and value. | |||
| class MIND_API AbstractTensor : public AbstractBase { | |||
| public: | |||
| MIND_API_BASE_MEMBER(AbstractTensor); | |||
| /// \brief Create AbstractTensor from the given type and shape. | |||
| /// | |||
| /// \param[in] type The data type id of the tensor. | |||
| /// \param[in] shape The shape of the tensor. | |||
| AbstractTensor(TypeId type, const ShapeVector &shape); | |||
| /// \brief Get the element abstract. | |||
| /// | |||
| /// \return A pointer to the element abstract. | |||
| AbstractBasePtr element() const; | |||
| /// \brief Get the shape of the abstract. | |||
| /// | |||
| /// \return A pointer to the shape. | |||
| ShapePtr shape() const; | |||
| }; | |||
| using AbstractTensorPtr = SharedPtr<AbstractTensor>; | |||
| /// \brief AbstractSequence describes the abstract for a tuple or list. | |||
| class MIND_API AbstractSequence : public AbstractBase { | |||
| public: | |||
| MIND_API_BASE_MEMBER(AbstractSequence); | |||
| /// \brief Get element abstracts. | |||
| /// | |||
| /// \return A vector of element abstracts. | |||
| AbstractBasePtrList elements() const; | |||
| }; | |||
| using AbstractSequencePtr = SharedPtr<AbstractSequence>; | |||
| /// \brief AbstractTuple describes the abstract for a tuple. | |||
| class MIND_API AbstractTuple : public AbstractSequence { | |||
| public: | |||
| MIND_API_BASE_MEMBER(AbstractTuple); | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_ | |||
| @@ -0,0 +1,232 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_ANF_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_ANF_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/ir/common.h" | |||
| #include "mindapi/ir/abstract.h" | |||
| #include "mindapi/ir/primitive.h" | |||
| #include "mindapi/ir/value.h" | |||
| namespace mindspore::api { | |||
| /// \brief AnfNode is the basic class of the IR graph node. | |||
| class MIND_API AnfNode : public Base { | |||
| public: | |||
| MIND_API_BASE_MEMBER(AnfNode); | |||
| /// \brief Obtain detailed information about scope namespace. | |||
| /// | |||
| /// \return Detailed information about scope namespace. | |||
| std::string fullname_with_scope() const; | |||
| /// \brief Obtain the inferred abstract value of this AnfNode. | |||
| /// | |||
| /// \return The inferred abstract value. | |||
| AbstractBasePtr abstract() const; | |||
| /// \brief Set the abstract value of this AnfNode. | |||
| /// | |||
| /// \param[in] abs New abstract value. | |||
| void set_abstract(const AbstractBasePtr &abs); | |||
| }; | |||
| /// \brief CNode represents a compute node with a set of input nodes. | |||
| class MIND_API CNode : public AnfNode { | |||
| public: | |||
| MIND_API_BASE_MEMBER(CNode); | |||
| /// \brief Get the number of inputs. | |||
| /// | |||
| /// \return The number of inputs in this CNode. | |||
| size_t size() const; | |||
| /// \brief Get the input node of the given index. | |||
| /// | |||
| /// \param[in] i The given index. | |||
| /// | |||
| /// \return The input node of the given index. | |||
| AnfNodePtr input(size_t i) const; | |||
| /// \brief Get the input nodes. | |||
| /// | |||
| /// \return The input nodes of this CNode. | |||
| std::vector<AnfNodePtr> inputs() const; | |||
| /// \brief Set the input nodes for this CNode. | |||
| /// | |||
| /// \param[in] inputs Input nodes. | |||
| void set_inputs(const std::vector<AnfNodePtr> &inputs); | |||
| /// \brief Add an input node to this CNode. | |||
| /// | |||
| /// \param[in] input the input node to be added. | |||
| void add_input(const AnfNodePtr &input); | |||
| /// \brief Set fullname_with_scope for this CNode. | |||
| /// | |||
| /// \param[in] full_name The fullname_with_scope. | |||
| void set_fullname_with_scope(const std::string &full_name); | |||
| /// \brief Add a new attribute to this CNode. | |||
| /// | |||
| /// \param[in] name The name of the new attribute. | |||
| /// \param[in] attr The value of the new attribute. | |||
| void AddAttr(const std::string &name, const ValuePtr &attr); | |||
| /// \brief Erase the attribute with the given name. | |||
| /// | |||
| /// \param[in] name The name of attribute. | |||
| void EraseAttr(const std::string &name); | |||
| /// \brief Get the attribute with the given name. | |||
| /// | |||
| /// \param[in] name The name of attribute. | |||
| /// \return Attribute. | |||
| ValuePtr GetAttr(const std::string &name) const; | |||
| }; | |||
| using CNodePtr = SharedPtr<CNode>; | |||
| /// \brief Parameter represents the parameter inputs of a function. | |||
| class MIND_API Parameter : public AnfNode { | |||
| public: | |||
| MIND_API_BASE_MEMBER(Parameter); | |||
| /// \brief Get the name of this Parameter. | |||
| /// | |||
| /// \return The name. | |||
| std::string name() const; | |||
| /// \brief Set the name of this Parameter. | |||
| /// | |||
| /// \param[in] The name. | |||
| void set_name(const std::string &name); | |||
| /// \brief Check if there is a default parameter. | |||
| /// | |||
| /// \return True if this Parameter has a default parameter, otherwise false. | |||
| bool has_default() const; | |||
| /// \brief Set the default parameter. | |||
| /// | |||
| /// \param[in] param The default parameter. | |||
| void set_default_param(const ValuePtr ¶m); | |||
| /// \brief Get the default parameter. | |||
| /// | |||
| /// \return The default parameter. | |||
| ValuePtr default_param() const; | |||
| }; | |||
| using ParameterPtr = SharedPtr<Parameter>; | |||
| /// \brief ValueNode is a graph node that hold a value. | |||
| class MIND_API ValueNode : public AnfNode { | |||
| public: | |||
| MIND_API_BASE_MEMBER(ValueNode); | |||
| /// \brief Create ValueNode with the given value. | |||
| /// | |||
| /// \param[in] value The value of this ValueNode. | |||
| explicit ValueNode(const ValuePtr &value); | |||
| /// \brief Get the value of this ValueNode. | |||
| /// | |||
| /// \return The value. | |||
| ValuePtr value() const; | |||
| }; | |||
| using ValueNodePtr = SharedPtr<ValueNode>; | |||
| // === ANF utility functions === // | |||
| /// \brief Create a ValueNode with the given value. | |||
| /// | |||
| /// \param[in] value The given value. | |||
| /// | |||
| /// \return The created ValueNode. | |||
| inline ValueNodePtr NewValueNode(const ValuePtr &value) { return MakeShared<ValueNode>(value); } | |||
| /// \brief Create a ValueNode with the given primitive type value. | |||
| /// | |||
| /// \param[in] value The given primitive type value. | |||
| /// | |||
| /// \return The created ValueNode. | |||
| template <typename T> | |||
| inline ValueNodePtr NewValueNode(T value) { | |||
| return NewValueNode(MakeValue(value)); | |||
| } | |||
| /// \brief Get the value from a node if it is a ValueNode. | |||
| /// | |||
| /// \param[in] node The node which may hold a value. | |||
| /// | |||
| /// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set. | |||
| inline ValuePtr GetValueNode(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return value_node->value(); | |||
| } | |||
| /// \brief Get the value with the given type from a node if it is a ValueNode. | |||
| /// | |||
| /// \param[in] node The node which may hold a value. | |||
| /// | |||
| /// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set, or value type is mismatch. | |||
| template <typename T, typename = typename std::enable_if_t< | |||
| is_wrapper_ptr<T>::value && std::is_base_of_v<Value, typename T::element_type>, T>> | |||
| inline T GetValueNode(const AnfNodePtr &node) { | |||
| auto value = GetValueNode(node); | |||
| if (value == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return value->cast<T>(); | |||
| } | |||
| /// \brief Check whether the given node is a cnode with the given Primitive as the first input. | |||
| /// | |||
| /// \param[in] node The given node to be checked. | |||
| /// \param[in] prim The Primitive value, nullptr means match any Primitive. | |||
| /// | |||
| /// \return True if the node is cnode and the first input is the given Primitive, false otherwise. | |||
| MIND_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim = nullptr); | |||
| /// \brief Check whether the given node is a ValueNode with the given Primitive. | |||
| /// | |||
| /// \param[in] node The given node to be checked. | |||
| /// \param[in] prim The Primitive value. | |||
| /// | |||
| /// \return True if the given node is a ValueNode with the given Primitive, false otherwise. | |||
| MIND_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim); | |||
| /// \brief Check if a node is a data node. | |||
| /// Some nodes may be used internally to pass some non-data states, those nodes are not data nodes. | |||
| /// | |||
| /// \param[in] node The node to be checked. | |||
| /// | |||
| /// \return True if the node is a data node, false otherwise. | |||
| MIND_API bool IsDataNode(const AnfNodePtr &node); | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_ANF_H_ | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_COMMON_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_COMMON_H_ | |||
| #include <vector> | |||
| #include "mindapi/base/shared_ptr.h" | |||
| namespace mindspore::api { | |||
| class AnfNode; | |||
| using AnfNodePtr = SharedPtr<AnfNode>; | |||
| using AnfNodePtrList = std::vector<AnfNodePtr>; | |||
| class Value; | |||
| using ValuePtr = SharedPtr<Value>; | |||
| class Primitive; | |||
| using PrimitivePtr = SharedPtr<Primitive>; | |||
| class Type; | |||
| using TypePtr = SharedPtr<Type>; | |||
| class AbstractBase; | |||
| using AbstractBasePtr = SharedPtr<AbstractBase>; | |||
| using AbstractBasePtrList = std::vector<AbstractBasePtr>; | |||
| class Shape; | |||
| using ShapePtr = SharedPtr<Shape>; | |||
| class FuncGraph; | |||
| using FuncGraphPtr = SharedPtr<FuncGraph>; | |||
| class FuncGraphManager; | |||
| using FuncGraphManagerPtr = SharedPtr<FuncGraphManager>; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_COMMON_H_ | |||
| @@ -0,0 +1,193 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/ir/common.h" | |||
| #include "mindapi/ir/anf.h" | |||
| #include "mindapi/ir/primitive.h" | |||
| #include "mindapi/ir/value.h" | |||
| #include "mindapi/ir/utils.h" | |||
| namespace mindspore { | |||
| class FuncGraphManager; | |||
| } | |||
| namespace mindspore::api { | |||
| /// \brief FuncGraph defines interface for a function graph. | |||
| class MIND_API FuncGraph : public Value { | |||
| public: | |||
| MIND_API_BASE_MEMBER(FuncGraph); | |||
| /// \brief Get the input parameters. | |||
| /// | |||
| /// \return Input parameters of this graph. | |||
| std::vector<AnfNodePtr> get_inputs() const; | |||
| /// \brief Get all parameters. | |||
| /// | |||
| /// \return All parameters of this graph. | |||
| std::vector<AnfNodePtr> parameters() const; | |||
| /// \brief Adds a parameter to this graph. | |||
| /// | |||
| /// \param[in] p The parameter to be added. | |||
| void add_parameter(const ParameterPtr &p); | |||
| /// \brief Adds a new parameter to this graph. | |||
| /// | |||
| /// \return The new added parameter. | |||
| ParameterPtr add_parameter(); | |||
| /// \brief Get the output node. | |||
| /// | |||
| /// \return The output node, nullptr if output not set. | |||
| AnfNodePtr output() const; | |||
| /// \brief Get the return CNode. | |||
| /// | |||
| /// \return The return CNode, nullptr if no return node. | |||
| CNodePtr get_return() const; | |||
| /// \brief Set the output node. | |||
| /// | |||
| /// \param[in] value The output node to be set. | |||
| /// \param[in] force_new_ret If true, a new return node is always created. | |||
| void set_output(const AnfNodePtr &value, bool force_new_ret = false); | |||
| /// \brief Set the return node. | |||
| /// | |||
| /// \param[in] cnode The return CNode to be set. | |||
| void set_return(const CNodePtr &cnode); | |||
| /// \brief Creates a new CNode in this graph. | |||
| /// | |||
| /// \param[in] inputs The input nodes of the new CNode. | |||
| /// | |||
| /// \return The created CNode. | |||
| CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()); | |||
| /// \brief Creates a new primitive CNode in this graph. | |||
| /// | |||
| /// \param[in] primitive The primitive of the new CNode. | |||
| /// \param[in] prim_inputs The argument inputs of the primitive CNode. | |||
| /// | |||
| /// \return The created primitive CNode. | |||
| CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs); | |||
| /// \brief Get all nodes in this graph. | |||
| /// | |||
| /// \return All nodes in this graph. | |||
| std::vector<AnfNodePtr> nodes() const; | |||
| /// \brief Check whether an attribute is set for this graph. | |||
| /// | |||
| /// \param[in] key The attribute key (name). | |||
| /// | |||
| /// \return True if the attribute with the given key is set, false otherwise. | |||
| bool has_attr(const std::string &key) const; | |||
| /// \brief Get an attribute value by its key. | |||
| /// | |||
| /// \param[in] key The attribute key (name). | |||
| /// | |||
| /// \return The attribute value for the given key, nullptr if attribute not found. | |||
| ValuePtr get_attr(const std::string &key) const; | |||
| /// \brief Set an attribute value. | |||
| /// | |||
| /// \param[in] key The attribute key (name). | |||
| /// \param[in] value The attribute value. | |||
| void set_attr(const std::string &key, const ValuePtr &value); | |||
| /// \brief Get the manager for this graph. | |||
| /// | |||
| /// \return The manager of this graph, nullptr if not set. | |||
| FuncGraphManagerPtr manager() const; | |||
| /// \brief Creates an empty function graph. | |||
| /// | |||
| /// \return The created function graph. | |||
| static FuncGraphPtr Create(); | |||
| /// \brief Topological sort a graph from the given end node. | |||
| /// | |||
| /// \param[in] node The end node of the graph to be sorted. | |||
| /// | |||
| /// \return The sorted nodes. | |||
| static std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &node); | |||
| }; | |||
| /// \brief FuncGraphManager defines interface for function graph management. | |||
| class MIND_API FuncGraphManager { | |||
| public: | |||
| /// \brief Create FuncGraphManager with the given implementor object. | |||
| /// | |||
| /// \param[in] impl The pointer to the implementor object. | |||
| explicit FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl); | |||
| /// \brief Get the shared_ptr to the underly implementation object. | |||
| /// | |||
| /// \return The shared_ptr to the underly implementation object. | |||
| const std::shared_ptr<mindspore::FuncGraphManager> &impl() const { return impl_; } | |||
| /// \brief Replace an old node with a new node, related edges are all updated. | |||
| /// | |||
| /// \param[in] old_node The old node to be replaced. | |||
| /// \param[in] new_node The new node that replace the old one. | |||
| /// | |||
| /// \return True if the node is successfully replaced, false otherwise. | |||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||
| /// \brief Change an existed edge by replace its input node. | |||
| /// | |||
| /// \param[in] node The output node of the edge. | |||
| /// \param[in] index The input index in output node. | |||
| /// \param[in] value The new input node of the edge. | |||
| void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); | |||
| /// \brief Adds a new edge between the given two nodes. | |||
| /// | |||
| /// \param[in] node The output node of the edge. | |||
| /// \param[in] value The input node of the edge. | |||
| void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value); | |||
| /// \brief Find users of the given node. | |||
| /// | |||
| /// \param[in] node The node. | |||
| /// | |||
| /// \return Users of the given node, empty if user not found. | |||
| std::vector<std::pair<AnfNodePtr, int>> GetUsers(const AnfNodePtr &node) const; | |||
| /// \brief Manage the give function graph. | |||
| /// | |||
| /// \param[in] func_graph The function graph to be managed. | |||
| /// \param[in] manage If true, the created manager will be set in the graph. | |||
| /// | |||
| /// \return The manager that manages the given function graph. | |||
| static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true); | |||
| private: | |||
| const std::shared_ptr<mindspore::FuncGraphManager> impl_; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_ | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/ir/common.h" | |||
| #include "mindapi/ir/value.h" | |||
| namespace mindspore::api { | |||
| /// \brief Primitive defines a primitive operator. | |||
| class MIND_API Primitive : public Value { | |||
| public: | |||
| MIND_API_BASE_MEMBER(Primitive); | |||
| /// \brief Create primitive with the given name. | |||
| /// | |||
| /// \param[in] name The primitive name. | |||
| explicit Primitive(const std::string &name); | |||
| /// \brief Get name of the primitive. | |||
| /// | |||
| /// \return The name of primitive. | |||
| const std::string &name() const; | |||
| /// \brief Add attribute to primitive. | |||
| /// | |||
| /// \param[in] name The attribute name. | |||
| /// \param[in] attr The attribute value. | |||
| /// \return The primitive to which attribute has been added. | |||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr); | |||
| /// \brief Add attributes by using a map, all elements of the map will be added to this primitive. | |||
| /// | |||
| /// \param[in] attrs The attribute map needs to be added in the primitive attribute. | |||
| /// \return The primitive to which attribute has been added. | |||
| Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs); | |||
| /// \brief Erase attribute to the primitive attribute map. | |||
| /// | |||
| /// \param[in] name The attribute name. | |||
| void EraseAttr(const std::string &name); | |||
| /// \brief Get attribute value by name. | |||
| /// | |||
| /// \param[in] name the attribute name. | |||
| /// \return The value of the attribute, null if attribute name not found. | |||
| ValuePtr GetAttr(const std::string &name) const; | |||
| /// \brief Check If Primitive has an attribute with then given name. | |||
| /// | |||
| /// \param[in] name The attribute name. | |||
| /// \return True if there is an attribute with the given name, otherwise false. | |||
| bool HasAttr(const std::string &name) const; | |||
| /// \brief Get all attributes of this primitive as a map. | |||
| /// | |||
| /// \return The attribute map of this primitive. | |||
| std::unordered_map<std::string, ValuePtr> attrs() const; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_ | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_ | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/base/shape_vector.h" | |||
| #include "mindapi/ir/common.h" | |||
| namespace mindspore::api { | |||
| /// \brief Shape defines dimensions of a tensor. | |||
| class MIND_API Shape : public Base { | |||
| public: | |||
| MIND_API_BASE_MEMBER(Shape); | |||
| /// \brief Get the shape dimensions. | |||
| /// | |||
| /// \return The shape dimensions. | |||
| const ShapeVector &shape() const; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_ | |||
| @@ -0,0 +1,93 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_ | |||
| #include <cstdint> | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/base/shape_vector.h" | |||
| #include "mindapi/base/type_id.h" | |||
| #include "mindapi/ir/common.h" | |||
| #include "mindapi/ir/value.h" | |||
| namespace mindspore::api { | |||
| /// \brief Tensor represents a multi-dimensional array of elements. | |||
| class MIND_API Tensor : public Value { | |||
| public: | |||
| MIND_API_BASE_MEMBER(Tensor); | |||
| /// \brief Create a lazy allocated tensor. | |||
| /// | |||
| /// \param[in] data_type [TypeId] Data type of the tensor. | |||
| /// \param[in] shape The shape represented by ShapeVector of the tensor. | |||
| Tensor(TypeId data_type, const ShapeVector &shape); | |||
| /// \brief Create a tensor with input data buffer. | |||
| /// | |||
| /// \param[in] data_type [TypeId] Data type of the tensor. | |||
| /// \param[in] shape The shape represented by ShapeVector of the tensor. | |||
| /// \param[in] data The input data to be copied into tensor. | |||
| /// \param[in] data_len The length of data in bytes. | |||
| Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len); | |||
| /// \brief Get the shape of the tensor. | |||
| /// The shape of a tensor is stored in a vector<int64_t>. Each element of the | |||
| /// vector represents the size of a dimension of the tensor. The order of each | |||
| /// element in the vector is the same as the the dimension's order it represents. | |||
| /// | |||
| /// \return A vector<int64_t> which represents the shape of the tensor. | |||
| const ShapeVector &shape() const; | |||
| /// \brief Set the shape of tensor. | |||
| /// | |||
| /// \param[in] shape The shape to be set. | |||
| void set_shape(const ShapeVector &shape); | |||
| /// \brief Get the data type of the tensor. | |||
| /// | |||
| /// \return The data type of the tensor. | |||
| TypeId data_type() const; | |||
| /// \brief Set the data type of the tensor. | |||
| /// | |||
| /// \param[in] data_type The data type to be set. | |||
| void set_data_type(const TypeId data_type); | |||
| /// \brief Get The pointer to the underlying memory block for data storage. | |||
| /// | |||
| /// \return The pointer to the underlying data. | |||
| const void *data() const; | |||
| /// \brief Get The pointer to the underlying memory block for data storage. | |||
| /// | |||
| /// \return The pointer to the underlying data. | |||
| void *data(); | |||
| /// \brief Get tensor data size. | |||
| /// | |||
| /// \return The total number of elements in the tensor. | |||
| int DataSize() const; | |||
| /// \brief Get tensor data size in bytes. | |||
| /// | |||
| /// \return The total number of bytes for the tensor data. | |||
| std::size_t Size() const; | |||
| }; | |||
| using TensorPtr = SharedPtr<Tensor>; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_TYPE_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_TYPE_H_ | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/base/type_id.h" | |||
| #include "mindapi/ir/common.h" | |||
| #include "mindapi/ir/value.h" | |||
| namespace mindspore::api { | |||
| /// \brief Type defines the type of a value. | |||
| class MIND_API Type : public Value { | |||
| public: | |||
| MIND_API_BASE_MEMBER(Type); | |||
| /// \brief Get the id of the Type object. | |||
| /// | |||
| /// \return The id of the Type object. | |||
| TypeId type_id() const; | |||
| /// \brief Get the number type of the Type object. | |||
| /// | |||
| /// \return The number type of this Type object, kTypeUnknown if this is not a number type. | |||
| TypeId number_type() const; | |||
| /// \brief Get the Type according to a TypeId. | |||
| /// | |||
| /// \param[in] id The id of the type. | |||
| /// | |||
| /// \return The pointer to the Type. | |||
| static TypePtr GetType(TypeId id); | |||
| /// \brief Get data size in bytes for the type according to a TypeId. | |||
| /// | |||
| /// \param[in] id The id of the type. | |||
| /// | |||
| /// \return The data size in bytes for the Type. | |||
| static size_t GetSize(TypeId id); | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_TYPE_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_UTILS_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_UTILS_H_ | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/base/shared_ptr.h" | |||
| #include "mindapi/base/type_traits.h" | |||
| #include "mindapi/ir/anf.h" | |||
| #include "mindapi/ir/value.h" | |||
| #include "mindapi/ir/func_graph.h" | |||
| namespace mindspore::api::utils { | |||
| /// \brief Check whether the given object is an instance of the given class. | |||
| /// | |||
| /// \param[in] ptr The pointer to the given object. | |||
| /// | |||
| /// \return True if the pointer is not null and the object is an instance of the given class, false otherwise. | |||
| template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>, T>> | |||
| bool isa(const BasePtr &ptr) { | |||
| if (ptr == nullptr) { | |||
| return false; | |||
| } | |||
| return ptr->isa<T>(); | |||
| } | |||
| /// \brief Cast the given object pointer to a pointer with the given class. | |||
| /// | |||
| /// \param[in] ptr The pointer to the object to casted. | |||
| /// | |||
| /// \return A non-null pointer if the input pointer is not null and cast success, nullptr otherwise. | |||
| template <typename T, typename = typename std::enable_if_t<is_wrapper_ptr<T>::value, T>> | |||
| T cast(const BasePtr &ptr) { | |||
| if (ptr == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return ptr->cast<T>(); | |||
| } | |||
| /// \brief Make a copy from the given function graph. | |||
| /// | |||
| /// \param[in] func_graph The graph to be cloned. | |||
| /// | |||
| /// \return The cloned graph. | |||
| MIND_API FuncGraphPtr CloneGraph(const FuncGraphPtr &func_graph); | |||
| /// \brief Get pad mode id from a value holds the pad mode name or id. | |||
| /// | |||
| /// \param[in] value The value holds the pad mode name or id. | |||
| /// \param[in] is_upper Indicates whether the name is uppercase or lowercase, default is false for lowercase. | |||
| /// | |||
| /// \return The pad mode id. | |||
| MIND_API int64_t GetPadMode(const ValuePtr &value, bool is_upper = false); | |||
| } // namespace mindspore::api::utils | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_UTILS_H_ | |||
| @@ -0,0 +1,270 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IR_VALUE_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IR_VALUE_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <type_traits> | |||
| #include "mindapi/base/base.h" | |||
| #include "mindapi/ir/common.h" | |||
| namespace mindspore::api { | |||
| template <typename T> | |||
| struct ImmTrait {}; | |||
| #define MIND_API_IMM_TRAIT(typeimm, prototype) \ | |||
| template <> \ | |||
| struct ImmTrait<prototype> { \ | |||
| using type = SharedPtr<typeimm>; \ | |||
| } | |||
| /// \brief Value represents a value in expression. | |||
| class MIND_API Value : public Base { | |||
| public: | |||
| MIND_API_BASE_MEMBER(Value); | |||
| /// \brief Get the type of this Value. | |||
| /// | |||
| /// \return The type. | |||
| TypePtr type() const; | |||
| /// \brief Get the abstract of this Value. | |||
| /// | |||
| /// \return Abstract of this Value. | |||
| AbstractBasePtr ToAbstract() const; | |||
| }; | |||
| /// \brief ValueSequence represents a sequence of values. | |||
| class MIND_API ValueSequence : public Value { | |||
| public: | |||
| MIND_API_BASE_MEMBER(ValueSequence); | |||
| /// \brief Get the size of this ValueSequence. | |||
| /// | |||
| /// \return The size as the number of elements. | |||
| std::size_t size() const; | |||
| /// \brief Get the list of values in this ValueSequence. | |||
| /// | |||
| /// \return The list of element values. | |||
| std::vector<ValuePtr> value() const; | |||
| }; | |||
| using ValueSequencePtr = SharedPtr<ValueSequence>; | |||
| /// \brief ValueTuple represents a value tuple. | |||
| class MIND_API ValueTuple : public ValueSequence { | |||
| public: | |||
| MIND_API_BASE_MEMBER(ValueTuple); | |||
| /// \brief Constructor of ValueTuple. | |||
| /// | |||
| /// \param[in] elements The elements of the tuple. | |||
| explicit ValueTuple(const std::vector<ValuePtr> &elements); | |||
| }; | |||
| using ValueTuplePtr = SharedPtr<ValueTuple>; | |||
| /// \brief StringImm defines a Value whose type is string. | |||
| class MIND_API StringImm : public Value { | |||
| public: | |||
| MIND_API_BASE_MEMBER(StringImm); | |||
| /// \brief Create StringImm with the given string. | |||
| /// | |||
| /// \param[in] str The given string value. | |||
| explicit StringImm(const std::string &str); | |||
| /// \brief Get the string value of this StringImm. | |||
| /// | |||
| /// \return The string value of this StringImm. | |||
| const std::string &value() const; | |||
| }; | |||
| using StringImmPtr = SharedPtr<StringImm>; | |||
| MIND_API_IMM_TRAIT(StringImm, std::string); | |||
| /// \beief Scalar defines interface for scalar data. | |||
| class MIND_API Scalar : public Value { | |||
| public: | |||
| MIND_API_BASE_MEMBER(Scalar); | |||
| }; | |||
| /// \beief BoolImm defines interface for bool data. | |||
| class MIND_API BoolImm : public Scalar { | |||
| public: | |||
| MIND_API_BASE_MEMBER(BoolImm); | |||
| /// \brief Create BoolImm with the given bool value. | |||
| /// | |||
| /// \param[in] b The given bool value. | |||
| explicit BoolImm(bool b); | |||
| /// \brief Get the bool value of this BoolImm. | |||
| /// | |||
| /// \return The bool value of this BoolImm. | |||
| bool value() const; | |||
| }; | |||
| using BoolImmPtr = SharedPtr<BoolImm>; | |||
| MIND_API_IMM_TRAIT(BoolImm, bool); | |||
| /// \beief IntegerImm defines interface for integer data. | |||
| class MIND_API IntegerImm : public Scalar { | |||
| public: | |||
| MIND_API_BASE_MEMBER(IntegerImm); | |||
| }; | |||
| /// \beief Int64Imm defines interface for int64 data. | |||
| class MIND_API Int64Imm : public IntegerImm { | |||
| public: | |||
| MIND_API_BASE_MEMBER(Int64Imm); | |||
| /// \brief Create Int64Imm with the given int64 value. | |||
| /// | |||
| /// \param[in] value The given bool value. | |||
| explicit Int64Imm(int64_t value); | |||
| /// \brief Get the int64 value of this Int64Imm. | |||
| /// | |||
| /// \return The int64 value of this Int64Imm. | |||
| int64_t value() const; | |||
| }; | |||
| using Int64ImmPtr = SharedPtr<Int64Imm>; | |||
| MIND_API_IMM_TRAIT(Int64Imm, int64_t); | |||
| /// \beief FloatImm defines interface for float data. | |||
| class MIND_API FloatImm : public Scalar { | |||
| public: | |||
| MIND_API_BASE_MEMBER(FloatImm); | |||
| }; | |||
| /// \beief FP32Imm defines interface for float32 data. | |||
| class MIND_API FP32Imm : public FloatImm { | |||
| public: | |||
| MIND_API_BASE_MEMBER(FP32Imm); | |||
| /// \brief Create FP32Imm with the given float value. | |||
| /// | |||
| /// \param[in] value The given float value. | |||
| explicit FP32Imm(float value); | |||
| /// \brief Get the float value of this FP32Imm. | |||
| /// | |||
| /// \return The float value of this FP32Imm. | |||
| float value() const; | |||
| }; | |||
| using FP32ImmPtr = SharedPtr<FP32Imm>; | |||
| MIND_API_IMM_TRAIT(FP32Imm, float); | |||
| // === Utility functions for Value === // | |||
| /// \brief brief Create a Value object from a primitive type value. | |||
| /// | |||
| /// \param[in] v The primitive type value. | |||
| /// | |||
| /// \return The created Value object with the given primitive type value. | |||
| template <typename T, typename U = typename ImmTrait<T>::type::element_type> | |||
| inline ValuePtr MakeValue(T v) { | |||
| return MakeShared<U>(v); | |||
| } | |||
| /// \brief brief Create a StringImm Value object from a C string. | |||
| /// | |||
| /// \param[in] s The C string. | |||
| /// | |||
| /// \return The created StringImm Value object. | |||
| inline ValuePtr MakeValue(const char *s) { return MakeShared<StringImm>(std::string(s)); } | |||
| /// \brief brief Create a Int64Imm Value object from a int value. | |||
| /// | |||
| /// \param[in] i The int value. | |||
| /// | |||
| /// \return The created Int64Imm Value object. | |||
| inline ValuePtr MakeValue(int i) { return MakeShared<Int64Imm>(static_cast<int64_t>(i)); } | |||
| /// \brief brief Create a ValueSequence object from a vector of values. | |||
| /// | |||
| /// \param[in] values The vector of values. | |||
| /// | |||
| /// \return The created ValueSequence object. | |||
| inline ValuePtr MakeValue(const std::vector<ValuePtr> &values) { return MakeShared<ValueTuple>(values); } | |||
| /// \brief Create a ValueSequence object from a vector of primitive type values. | |||
| /// | |||
| /// \param[in] values The vector of primitive values. | |||
| /// | |||
| /// \return The created ValueSequence object. | |||
| template <typename T, typename = typename std::enable_if_t<is_vector<T>::value, T>> | |||
| inline ValuePtr MakeValue(const T &values) { | |||
| std::vector<ValuePtr> value_vector; | |||
| value_vector.reserve(values.size()); | |||
| for (auto &value : values) { | |||
| value_vector.emplace_back(MakeValue(value)); | |||
| } | |||
| return MakeShared<ValueTuple>(value_vector); | |||
| } | |||
| /// \brief brief Get primitive type value from a Value object. | |||
| /// | |||
| /// \param[in] value The pointer to the Value object. | |||
| /// | |||
| /// \return The primitive type value of the Value object. | |||
| template <typename T, typename U = typename ImmTrait<T>::type> | |||
| inline T GetValue(const ValuePtr &value) { | |||
| if (value == nullptr) { | |||
| return T(); | |||
| } | |||
| U imm = value->cast<U>(); | |||
| if (imm == nullptr) { | |||
| return T(); | |||
| } | |||
| return imm->value(); | |||
| } | |||
| /// \brief brief Get primitive element values from a ValueSequeue object. | |||
| /// | |||
| /// \param[in] value The pointer to the ValueSequeue object. | |||
| /// | |||
| /// \return The primitive type values as a vector. | |||
| template <typename T, typename S = typename std::decay_t<T>, | |||
| typename U = typename std::enable_if_t<is_vector<S>::value, typename S::value_type>> | |||
| inline std::vector<U> GetValue(const ValuePtr &value) { | |||
| if (value == nullptr) { | |||
| return {}; | |||
| } | |||
| auto seq = value->cast<ValueSequencePtr>(); | |||
| if (seq == nullptr) { | |||
| return {}; | |||
| } | |||
| auto elements = seq->value(); | |||
| std::vector<U> result; | |||
| result.reserve(elements.size()); | |||
| for (auto &e : elements) { | |||
| result.emplace_back(GetValue<U>(e)); | |||
| } | |||
| return result; | |||
| } | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IR_VALUE_H_ | |||
| @@ -0,0 +1,84 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindapi/ir/abstract.h" | |||
| #include "mindapi/src/helper.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "ir/dtype.h" | |||
| #include "ir/value.h" | |||
| namespace mindspore::api { | |||
| using TypeImpl = mindspore::Type; | |||
| using ValueImpl = mindspore::Value; | |||
| using AbstractBaseImpl = mindspore::abstract::AbstractBase; | |||
| MIND_API_BASE_IMPL(AbstractBase, AbstractBaseImpl, Base); | |||
| AbstractBasePtr AbstractBase::Clone() const { | |||
| auto abs = ToRef<AbstractBaseImpl>(impl_).Clone(); | |||
| return ToWrapper<AbstractBase>(abs); | |||
| } | |||
| TypePtr AbstractBase::type() const { | |||
| auto t = ToRef<AbstractBaseImpl>(impl_).BuildType(); | |||
| return ToWrapper<Type>(t); | |||
| } | |||
| ValuePtr AbstractBase::value() const { | |||
| auto v = ToRef<AbstractBaseImpl>(impl_).BuildValue(); | |||
| return ToWrapper<Value>(v); | |||
| } | |||
| void AbstractBase::set_type(const TypePtr &type) { | |||
| auto type_impl = ToImpl<TypeImpl>(type); | |||
| ToRef<AbstractBaseImpl>(impl_).set_type(type_impl); | |||
| } | |||
| void AbstractBase::set_value(const ValuePtr &value) { | |||
| auto value_impl = ToImpl<ValueImpl>(value); | |||
| ToRef<AbstractBaseImpl>(impl_).set_value(value_impl); | |||
| } | |||
| using AbstractTensorImpl = mindspore::abstract::AbstractTensor; | |||
| MIND_API_BASE_IMPL(AbstractTensor, AbstractTensorImpl, AbstractBase); | |||
| AbstractTensor::AbstractTensor(TypeId type, const ShapeVector &shape) | |||
| : AbstractBase(std::make_shared<AbstractTensorImpl>(mindspore::TypeIdToType(type), shape)) {} | |||
| AbstractBasePtr AbstractTensor::element() const { | |||
| auto abs = ToRef<AbstractTensorImpl>(impl_).element(); | |||
| return ToWrapper<AbstractBase>(abs); | |||
| } | |||
| ShapePtr AbstractTensor::shape() const { | |||
| auto s = ToRef<AbstractTensorImpl>(impl_).shape(); | |||
| return ToWrapper<Shape>(s); | |||
| } | |||
| using AbstractSequenceImpl = mindspore::abstract::AbstractSequeue; | |||
| MIND_API_BASE_IMPL(AbstractSequence, AbstractSequenceImpl, AbstractBase); | |||
| AbstractBasePtrList AbstractSequence::elements() const { | |||
| auto &impl_elements = ToRef<AbstractSequenceImpl>(impl_).elements(); | |||
| return ToWrapperVector<AbstractBase>(impl_elements); | |||
| } | |||
| using AbstractTupleImpl = mindspore::abstract::AbstractTuple; | |||
| MIND_API_BASE_IMPL(AbstractTuple, AbstractTupleImpl, AbstractSequence); | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,135 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindapi/ir/anf.h" | |||
| #include "mindapi/src/helper.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/value.h" | |||
| #include "ir/primitive.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore::api { | |||
| using ValueImpl = mindspore::Value; | |||
| using AnfNodeImpl = mindspore::AnfNode; | |||
| using PrimitiveImpl = mindspore::Primitive; | |||
| using AbstractBaseImpl = mindspore::abstract::AbstractBase; | |||
| MIND_API_BASE_IMPL(AnfNode, AnfNodeImpl, Base); | |||
| std::string AnfNode::fullname_with_scope() const { return ToRef<AnfNodeImpl>(impl_).fullname_with_scope(); } | |||
| AbstractBasePtr AnfNode::abstract() const { | |||
| const auto &abs = ToRef<AnfNodeImpl>(impl_).abstract(); | |||
| return ToWrapper<AbstractBase>(abs); | |||
| } | |||
| void AnfNode::set_abstract(const AbstractBasePtr &abs) { | |||
| ToRef<AnfNodeImpl>(impl_).set_abstract(ToImpl<AbstractBaseImpl>(abs)); | |||
| } | |||
| using CNodeImpl = mindspore::CNode; | |||
| MIND_API_BASE_IMPL(CNode, CNodeImpl, AnfNode); | |||
| size_t CNode::size() const { return ToRef<CNodeImpl>(impl_).size(); } | |||
| AnfNodePtr CNode::input(size_t i) const { | |||
| auto &input = ToRef<CNodeImpl>(impl_).input(i); | |||
| return ToWrapper<AnfNode>(input); | |||
| } | |||
| std::vector<AnfNodePtr> CNode::inputs() const { | |||
| auto &impl_inputs = ToRef<CNodeImpl>(impl_).inputs(); | |||
| return ToWrapperVector<AnfNode>(impl_inputs); | |||
| } | |||
| void CNode::set_inputs(const std::vector<AnfNodePtr> &inputs) { | |||
| auto impl_inputs = ToImplVector<AnfNodeImpl>(inputs); | |||
| ToRef<CNodeImpl>(impl_).set_inputs(impl_inputs); | |||
| } | |||
| void CNode::add_input(const AnfNodePtr &input) { | |||
| auto impl_input = ToImpl<AnfNodeImpl>(input); | |||
| MS_EXCEPTION_IF_NULL(impl_input); | |||
| ToRef<CNodeImpl>(impl_).add_input(impl_input); | |||
| } | |||
| void CNode::set_fullname_with_scope(const std::string &full_name) { | |||
| ToRef<CNodeImpl>(impl_).set_fullname_with_scope(full_name); | |||
| } | |||
| void CNode::AddAttr(const std::string &name, const ValuePtr &attr) { | |||
| auto impl_attr = ToImpl<ValueImpl>(attr); | |||
| MS_EXCEPTION_IF_NULL(impl_attr); | |||
| ToRef<CNodeImpl>(impl_).AddAttr(name, impl_attr); | |||
| } | |||
| void CNode::EraseAttr(const std::string &name) { ToRef<CNodeImpl>(impl_).EraseAttr(name); } | |||
| ValuePtr CNode::GetAttr(const std::string &name) const { | |||
| auto v = ToRef<CNodeImpl>(impl_).GetAttr(name); | |||
| return ToWrapper<Value>(v); | |||
| } | |||
| using ParameterImpl = mindspore::Parameter; | |||
| MIND_API_BASE_IMPL(Parameter, ParameterImpl, AnfNode); | |||
| std::string Parameter::name() const { return ToRef<ParameterImpl>(impl_).name(); } | |||
| void Parameter::set_name(const std::string &name) { ToRef<ParameterImpl>(impl_).set_name(name); } | |||
| bool Parameter::has_default() const { return ToRef<ParameterImpl>(impl_).has_default(); } | |||
| void Parameter::set_default_param(const ValuePtr ¶m) { | |||
| auto v = ToImpl<ValueImpl>(param); | |||
| ToRef<ParameterImpl>(impl_).set_default_param(v); | |||
| } | |||
| ValuePtr Parameter::default_param() const { | |||
| auto v = ToRef<ParameterImpl>(impl_).default_param(); | |||
| return ToWrapper<Value>(v); | |||
| } | |||
| using ValueNodeImpl = mindspore::ValueNode; | |||
| MIND_API_BASE_IMPL(ValueNode, ValueNodeImpl, AnfNode); | |||
| ValueNode::ValueNode(const ValuePtr &value) : AnfNode(std::make_shared<ValueNodeImpl>(ToImpl<ValueImpl>(value))) {} | |||
| ValuePtr ValueNode::value() const { | |||
| auto v = ToRef<ValueNodeImpl>(impl_).value(); | |||
| return ToWrapper<Value>(v); | |||
| } | |||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim) { | |||
| auto node_impl = ToImpl<AnfNodeImpl>(node); | |||
| auto prim_impl = ToImpl<PrimitiveImpl>(prim); | |||
| return mindspore::IsPrimitiveCNode(node_impl, prim_impl); | |||
| } | |||
| bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim) { | |||
| auto node_impl = ToImpl<AnfNodeImpl>(node); | |||
| auto prim_impl = ToImpl<PrimitiveImpl>(prim); | |||
| return mindspore::IsPrimitive(node_impl, prim_impl); | |||
| } | |||
| bool IsDataNode(const AnfNodePtr &node) { | |||
| auto node_impl = ToImpl<AnfNodeImpl>(node); | |||
| // We assume that node with monad abstract is not a data node. | |||
| return !HasAbstractMonad(node_impl); | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindapi/base/base.h" | |||
| #include "base/base.h" | |||
| namespace mindspore::api { | |||
| Base::Base(const std::shared_ptr<mindspore::Base> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl_); } | |||
| uint32_t Base::ClassId() { return mindspore::Base::kTypeId; } | |||
| bool Base::IsFromClassId(uint32_t class_id) const { return impl_->IsFromTypeId(class_id); } | |||
| std::string Base::ToString() const { return impl_->ToString(); } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,170 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include "mindapi/ir/func_graph.h" | |||
| #include "mindapi/src/helper.h" | |||
| #define USE_DEPRECATED_API | |||
| #include "ir/anf.h" | |||
| #include "ir/value.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/manager.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/graph_utils.h" | |||
| namespace mindspore::api { | |||
| using ValueImpl = mindspore::Value; | |||
| using AnfNodeImpl = mindspore::AnfNode; | |||
| using CNodeImpl = mindspore::CNode; | |||
| using PrimitiveImpl = mindspore::Primitive; | |||
| using ParameterImpl = mindspore::Parameter; | |||
| using FuncGraphImpl = mindspore::FuncGraph; | |||
| using FuncGraphManagerImpl = mindspore::FuncGraphManager; | |||
| MIND_API_BASE_IMPL(FuncGraph, FuncGraphImpl, Value); | |||
| std::vector<AnfNodePtr> FuncGraph::get_inputs() const { | |||
| auto &inputs = ToRef<FuncGraphImpl>(impl_).get_inputs(); | |||
| return ToWrapperVector<AnfNode>(inputs); | |||
| } | |||
| std::vector<AnfNodePtr> FuncGraph::parameters() const { | |||
| auto ¶ms = ToRef<FuncGraphImpl>(impl_).parameters(); | |||
| return ToWrapperVector<AnfNode>(params); | |||
| } | |||
| void FuncGraph::add_parameter(const ParameterPtr &p) { | |||
| auto param_impl = ToImpl<ParameterImpl>(p); | |||
| ToRef<FuncGraphImpl>(impl_).add_parameter(param_impl); | |||
| } | |||
| ParameterPtr FuncGraph::add_parameter() { | |||
| auto param_impl = ToRef<FuncGraphImpl>(impl_).add_parameter(); | |||
| return ToWrapper<Parameter>(param_impl); | |||
| } | |||
| AnfNodePtr FuncGraph::output() const { | |||
| auto output = ToRef<FuncGraphImpl>(impl_).output(); | |||
| return ToWrapper<AnfNode>(output); | |||
| } | |||
| CNodePtr FuncGraph::get_return() const { | |||
| auto ret = ToRef<FuncGraphImpl>(impl_).get_return(); | |||
| return ToWrapper<CNode>(ret); | |||
| } | |||
| void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { | |||
| auto output = ToImpl<AnfNodeImpl>(value); | |||
| ToRef<FuncGraphImpl>(impl_).set_output(output); | |||
| } | |||
| void FuncGraph::set_return(const CNodePtr &cnode) { | |||
| auto cnode_impl = ToImpl<CNodeImpl>(cnode); | |||
| ToRef<FuncGraphImpl>(impl_).set_return(cnode_impl); | |||
| } | |||
| CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| auto inputs_impl = ToImplVector<AnfNodeImpl>(inputs); | |||
| auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(std::move(inputs_impl)); | |||
| return ToWrapper<CNode>(cnode_impl); | |||
| } | |||
| CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) { | |||
| auto prim_impl = ToImpl<PrimitiveImpl>(primitive); | |||
| auto prim_inputs_impl = ToImplVector<AnfNodeImpl>(prim_inputs); | |||
| auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(prim_impl, prim_inputs_impl); | |||
| return ToWrapper<CNode>(cnode_impl); | |||
| } | |||
| std::vector<AnfNodePtr> FuncGraph::nodes() const { | |||
| auto &nodes = ToRef<FuncGraphImpl>(impl_).nodes(); | |||
| return ToWrapperVector<AnfNode>(nodes); | |||
| } | |||
| bool FuncGraph::has_attr(const std::string &key) const { return ToRef<FuncGraphImpl>(impl_).has_attr(key); } | |||
| ValuePtr FuncGraph::get_attr(const std::string &key) const { | |||
| auto v = ToRef<FuncGraphImpl>(impl_).get_attr(key); | |||
| return ToWrapper<Value>(v); | |||
| } | |||
| void FuncGraph::set_attr(const std::string &key, const ValuePtr &value) { | |||
| auto value_impl = ToImpl<ValueImpl>(value); | |||
| ToRef<FuncGraphImpl>(impl_).set_attr(key, value_impl); | |||
| } | |||
| FuncGraphManagerPtr FuncGraph::manager() const { | |||
| auto manager = ToRef<FuncGraphImpl>(impl_).manager(); | |||
| if (manager == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return MakeShared<FuncGraphManager>(manager); | |||
| } | |||
| FuncGraphPtr FuncGraph::Create() { | |||
| auto fg = std::make_shared<FuncGraphImpl>(); | |||
| return ToWrapper<FuncGraph>(fg); | |||
| } | |||
| std::vector<AnfNodePtr> FuncGraph::TopoSort(const AnfNodePtr &node) { | |||
| auto node_impl = ToImpl<AnfNodeImpl>(node); | |||
| if (node_impl == nullptr) { | |||
| return {}; | |||
| } | |||
| auto sorted = mindspore::TopoSort(node_impl); | |||
| return ToWrapperVector<AnfNode>(sorted); | |||
| } | |||
| // FuncGraphManager is not derived from Base, we implement it directly. | |||
| FuncGraphManager::FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl) : impl_(impl) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| } | |||
| bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||
| return impl_->Replace(ToImpl<AnfNodeImpl>(old_node), ToImpl<AnfNodeImpl>(new_node)); | |||
| } | |||
| void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { | |||
| return impl_->SetEdge(ToImpl<AnfNodeImpl>(node), index, ToImpl<AnfNodeImpl>(value)); | |||
| } | |||
| void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) { | |||
| return impl_->AddEdge(ToImpl<AnfNodeImpl>(node), ToImpl<AnfNodeImpl>(value)); | |||
| } | |||
| std::vector<std::pair<AnfNodePtr, int>> FuncGraphManager::GetUsers(const AnfNodePtr &node) const { | |||
| auto &node_users = impl_->node_users(); | |||
| auto iter = node_users.find(ToImpl<AnfNodeImpl>(node)); | |||
| if (iter == node_users.end()) { | |||
| return {}; | |||
| } | |||
| auto &users_impl = iter->second; | |||
| std::vector<std::pair<AnfNodePtr, int>> users; | |||
| users.reserve(users_impl.size()); | |||
| std::transform(users_impl.begin(), users_impl.end(), std::back_inserter(users), | |||
| [](const auto &user) { return std::make_pair(ToWrapper<AnfNode>(user.first), user.second); }); | |||
| return users; | |||
| } | |||
| FuncGraphManagerPtr FuncGraphManager::Manage(const FuncGraphPtr &func_graph, bool manage) { | |||
| auto fg_impl = ToImpl<FuncGraphImpl>(func_graph); | |||
| auto mgr_impl = mindspore::Manage(fg_impl, manage); | |||
| if (mgr_impl == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return MakeShared<FuncGraphManager>(mgr_impl); | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_ | |||
| #define MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <type_traits> | |||
| #include "mindapi/base/base.h" | |||
| namespace mindspore::api { | |||
| template <typename T, typename U> | |||
| T &ToRef(const std::shared_ptr<U> &ptr) { | |||
| return static_cast<T &>(*ptr); | |||
| } | |||
| template <typename T, typename U, typename = typename std::enable_if_t<std::is_base_of_v<mindspore::Base, T>>, | |||
| typename = typename std::enable_if_t<std::is_base_of_v<Base, U>>> | |||
| std::shared_ptr<T> ToImpl(const SharedPtr<U> &wrapper) { | |||
| if (wrapper == nullptr || wrapper->impl() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return std::dynamic_pointer_cast<T>(wrapper->impl()); | |||
| } | |||
| template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>>> | |||
| SharedPtr<T> ToWrapper(const std::shared_ptr<mindspore::Base> &impl) { | |||
| if (impl == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return MakeShared<T>(impl); | |||
| } | |||
| template <typename T, typename U> | |||
| std::vector<std::shared_ptr<T>> ToImplVector(const U &wrapper_vector) { | |||
| std::vector<std::shared_ptr<T>> impl_vector; | |||
| impl_vector.reserve(wrapper_vector.size()); | |||
| for (auto &wrapper : wrapper_vector) { | |||
| impl_vector.emplace_back(ToImpl<T>(wrapper)); | |||
| } | |||
| return impl_vector; | |||
| } | |||
| template <typename T, typename U> | |||
| std::vector<SharedPtr<T>> ToWrapperVector(const U &impl_vector) { | |||
| std::vector<SharedPtr<T>> wrapper_vector; | |||
| wrapper_vector.reserve(impl_vector.size()); | |||
| for (auto &impl : impl_vector) { | |||
| wrapper_vector.emplace_back(ToWrapper<T>(impl)); | |||
| } | |||
| return wrapper_vector; | |||
| } | |||
| #define MIND_API_BASE_IMPL(current_class, impl_class, base_class) \ | |||
| current_class::current_class(const std::shared_ptr<mindspore::Base> &impl) : base_class(impl) { \ | |||
| if (!impl_->isa<impl_class>()) { \ | |||
| MS_LOG(EXCEPTION) << "Wrong impl " << impl_->type_name() << " for " << #current_class; \ | |||
| } \ | |||
| } \ | |||
| uint32_t current_class::ClassId() { return impl_class::kTypeId; } | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_ | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #define MIND_LOG_NO_MS_LOG | |||
| #include "mindapi/base/logging.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore::api { | |||
| static MsLogLevel ToMsLogLevel(LogLevel level) { | |||
| switch (level) { | |||
| case LogLevel::DEBUG: | |||
| return MsLogLevel::DEBUG; | |||
| case LogLevel::INFO: | |||
| return MsLogLevel::INFO; | |||
| case LogLevel::WARNING: | |||
| return MsLogLevel::WARNING; | |||
| case LogLevel::ERROR: | |||
| return MsLogLevel::ERROR; | |||
| case LogLevel::EXCEPTION: | |||
| return MsLogLevel::EXCEPTION; | |||
| default: | |||
| return MsLogLevel::EXCEPTION; | |||
| } | |||
| } | |||
| class LogWriterImpl { | |||
| public: | |||
| LogWriterImpl(LogLevel level, const char *file, int line, const char *func) | |||
| : writer_(LocationInfo(file, line, func), ToMsLogLevel(level), SubModuleId::SM_API) {} | |||
| ~LogWriterImpl() = default; | |||
| void Write(const LogStream &stream) const noexcept { | |||
| mindspore::LogStream log_stream; | |||
| log_stream << stream.stream_.rdbuf(); | |||
| writer_ < log_stream; | |||
| } | |||
| void WriteAndThrow(const LogStream &stream) const __attribute__((noreturn)) { | |||
| mindspore::LogStream log_stream; | |||
| log_stream << stream.stream_.rdbuf(); | |||
| writer_ ^ log_stream; | |||
| } | |||
| private: | |||
| mindspore::LogWriter writer_; | |||
| }; | |||
| LogWriter::LogWriter(LogLevel level, const char *file, int line, const char *func) | |||
| : impl_(std::make_unique<LogWriterImpl>(level, file, line, func)) {} | |||
| LogWriter::~LogWriter() = default; | |||
| void LogWriter::operator<(const LogStream &stream) const noexcept { impl_->Write(stream); } | |||
| void LogWriter::operator^(const LogStream &stream) const { impl_->WriteAndThrow(stream); } | |||
| bool LogWriter::IsEnabled(LogLevel level) { | |||
| auto log_level = ToMsLogLevel(level); | |||
| return IS_OUTPUT_ON(log_level); | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindapi/ir/primitive.h" | |||
| #include "mindapi/src/helper.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/value.h" | |||
| namespace mindspore::api { | |||
| using ValueImpl = mindspore::Value; | |||
| using PrimitiveImpl = mindspore::Primitive; | |||
| MIND_API_BASE_IMPL(Primitive, PrimitiveImpl, Value); | |||
| Primitive::Primitive(const std::string &name) : Value(std::make_shared<PrimitiveImpl>(name)) {} | |||
| const std::string &Primitive::name() const { return ToRef<PrimitiveImpl>(impl_).name(); } | |||
| Primitive &Primitive::AddAttr(const std::string &name, const ValuePtr &attr) { | |||
| auto value = ToImpl<ValueImpl>(attr); | |||
| ToRef<PrimitiveImpl>(impl_).set_attr(name, value); | |||
| return *this; | |||
| } | |||
| Primitive &Primitive::SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||
| for (auto &attr : attrs) { | |||
| auto value = ToImpl<ValueImpl>(attr.second); | |||
| ToRef<PrimitiveImpl>(impl_).set_attr(attr.first, value); | |||
| } | |||
| return *this; | |||
| } | |||
| void Primitive::EraseAttr(const std::string &name) { ToRef<PrimitiveImpl>(impl_).EraseAttr(name); } | |||
| ValuePtr Primitive::GetAttr(const std::string &name) const { | |||
| auto v = ToRef<PrimitiveImpl>(impl_).GetAttr(name); | |||
| return ToWrapper<Value>(v); | |||
| } | |||
| bool Primitive::HasAttr(const std::string &name) const { return ToRef<PrimitiveImpl>(impl_).HasAttr(name); } | |||
| std::unordered_map<std::string, ValuePtr> Primitive::attrs() const { | |||
| std::unordered_map<std::string, ValuePtr> attr_map; | |||
| auto &impl_attrs = ToRef<PrimitiveImpl>(impl_).attrs(); | |||
| attr_map.reserve(impl_attrs.size()); | |||
| for (auto &attr : impl_attrs) { | |||
| auto value = ToWrapper<Value>(attr.second); | |||
| attr_map.emplace(attr.first, value); | |||
| } | |||
| return attr_map; | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindapi/ir/shape.h" | |||
| #include "mindapi/src/helper.h" | |||
| #include "abstract/dshape.h" | |||
| namespace mindspore::api { | |||
| using ShapeImpl = mindspore::abstract::Shape; | |||
| MIND_API_BASE_IMPL(Shape, ShapeImpl, Base); | |||
| const ShapeVector &Shape::shape() const { return ToRef<ShapeImpl>(impl_).shape(); } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "mindapi/ir/tensor.h" | |||
| #include "mindapi/src/helper.h" | |||
| #include "ir/tensor.h" | |||
| namespace mindspore::api { | |||
| using TensorImpl = mindspore::tensor::Tensor; | |||
| MIND_API_BASE_IMPL(Tensor, TensorImpl, Value); | |||
| Tensor::Tensor(TypeId data_type, const ShapeVector &shape) : Value(std::make_shared<TensorImpl>(data_type, shape)) {} | |||
| Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len) | |||
| : Value(std::make_shared<TensorImpl>(data_type, shape, data, data_len)) {} | |||
| const ShapeVector &Tensor::shape() const { return ToRef<TensorImpl>(impl_).shape(); } | |||
| void Tensor::set_shape(const ShapeVector &shape) { (void)ToRef<TensorImpl>(impl_).set_shape(shape); } | |||
| TypeId Tensor::data_type() const { return ToRef<TensorImpl>(impl_).data_type(); } | |||
| void Tensor::set_data_type(const TypeId data_type) { (void)ToRef<TensorImpl>(impl_).set_data_type(data_type); } | |||
| const void *Tensor::data() const { return ToRef<TensorImpl>(impl_).data_c(); } | |||
| void *Tensor::data() { return ToRef<TensorImpl>(impl_).data_c(); } | |||
| int Tensor::DataSize() const { return ToRef<TensorImpl>(impl_).DataSize(); } | |||
| size_t Tensor::Size() const { return ToRef<TensorImpl>(impl_).Size(); } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindapi/ir/type.h" | |||
| #include "mindapi/ir/value.h" | |||
| #include "mindapi/src/helper.h" | |||
| #include "ir/dtype/type.h" | |||
| #include "ir/dtype.h" | |||
| #include "abstract/utils.h" | |||
| namespace mindspore::api { | |||
| using TypeImpl = mindspore::Type; | |||
| MIND_API_BASE_IMPL(Type, TypeImpl, Value); | |||
| TypeId Type::type_id() const { return ToRef<TypeImpl>(impl_).type_id(); } | |||
| TypeId Type::number_type() const { return ToRef<TypeImpl>(impl_).number_type(); } | |||
| TypePtr Type::GetType(TypeId id) { | |||
| auto type_impl = mindspore::TypeIdToType(id); | |||
| return ToWrapper<Type>(type_impl); | |||
| } | |||
| size_t Type::GetSize(TypeId id) { return mindspore::abstract::TypeIdSize(id); } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindapi/ir/utils.h" | |||
| #include "mindapi/src/helper.h" | |||
| #define USE_DEPRECATED_API | |||
| #include "ir/anf.h" | |||
| #include "ir/value.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore::api::utils { | |||
| using ValueImpl = mindspore::Value; | |||
| using FuncGraphImpl = mindspore::FuncGraph; | |||
| MIND_API FuncGraphPtr CloneGraph(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto fg_impl = ToImpl<FuncGraphImpl>(func_graph); | |||
| Cloner cloner({fg_impl}, false, true, true, std::make_shared<TraceCopy>(), nullptr); | |||
| auto cloned_fg = cloner[fg_impl]; | |||
| return ToWrapper<api::FuncGraph>(cloned_fg); | |||
| } | |||
| int64_t GetPadMode(const api::ValuePtr &value, bool is_upper) { | |||
| int64_t result; | |||
| auto value_impl = ToImpl<ValueImpl>(value); | |||
| CheckAndConvertUtils::GetPadModEnumValue(value_impl, &result, is_upper); | |||
| return result; | |||
| } | |||
| } // namespace mindspore::api::utils | |||
| @@ -0,0 +1,94 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindapi/ir/value.h" | |||
| #include "mindapi/ir/type.h" | |||
| #include "mindapi/ir/abstract.h" | |||
| #include "mindapi/src/helper.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/dtype/type.h" | |||
| #include "ir/value.h" | |||
| #include "ir/scalar.h" | |||
| namespace mindspore::api { | |||
| using ValueImpl = mindspore::Value; | |||
| using ValueSequenceImpl = mindspore::ValueSequeue; // 'Sequeue' is typo. | |||
| using ValueTupleImpl = mindspore::ValueTuple; | |||
| using StringImmImpl = mindspore::StringImm; | |||
| using ScalarImpl = mindspore::Scalar; | |||
| using BoolImmImpl = mindspore::BoolImm; | |||
| using IntegerImmImpl = mindspore::IntergerImm; // 'Interger' is typo. | |||
| using Int64ImmImpl = mindspore::Int64Imm; | |||
| using FloatImmImpl = mindspore::FloatImm; | |||
| using FP32ImmImpl = mindspore::FP32Imm; | |||
| MIND_API_BASE_IMPL(Value, ValueImpl, Base); | |||
| TypePtr Value::type() const { | |||
| auto t = ToRef<ValueImpl>(impl_).type(); | |||
| return ToWrapper<Type>(t); | |||
| } | |||
| AbstractBasePtr Value::ToAbstract() const { | |||
| auto abs = ToRef<ValueImpl>(impl_).ToAbstract(); | |||
| return ToWrapper<AbstractBase>(abs); | |||
| } | |||
| MIND_API_BASE_IMPL(ValueSequence, ValueSequenceImpl, Value); | |||
| std::size_t ValueSequence::size() const { return ToRef<ValueSequenceImpl>(impl_).size(); } | |||
| std::vector<ValuePtr> ValueSequence::value() const { | |||
| auto &elements = ToRef<ValueSequenceImpl>(impl_).value(); | |||
| return ToWrapperVector<Value>(elements); | |||
| } | |||
| MIND_API_BASE_IMPL(ValueTuple, ValueTupleImpl, ValueSequence); | |||
| ValueTuple::ValueTuple(const std::vector<ValuePtr> &elements) | |||
| : ValueSequence(std::make_shared<ValueTupleImpl>(ToImplVector<ValueImpl>(elements))) {} | |||
| MIND_API_BASE_IMPL(StringImm, StringImmImpl, Value); | |||
| StringImm::StringImm(const std::string &str) : Value(std::make_shared<StringImmImpl>(str)) {} | |||
| const std::string &StringImm::value() const { return ToRef<StringImmImpl>(impl_).value(); } | |||
| MIND_API_BASE_IMPL(Scalar, ScalarImpl, Value); | |||
| MIND_API_BASE_IMPL(BoolImm, BoolImmImpl, Scalar); | |||
| BoolImm::BoolImm(bool b) : Scalar(std::make_shared<BoolImmImpl>(b)) {} | |||
| bool BoolImm::value() const { return ToRef<BoolImmImpl>(impl_).value(); } | |||
| MIND_API_BASE_IMPL(IntegerImm, IntegerImmImpl, Scalar); | |||
| MIND_API_BASE_IMPL(Int64Imm, Int64ImmImpl, IntegerImm); | |||
| Int64Imm::Int64Imm(int64_t value) : IntegerImm(std::make_shared<Int64ImmImpl>(value)) {} | |||
| int64_t Int64Imm::value() const { return ToRef<Int64ImmImpl>(impl_).value(); } | |||
| MIND_API_BASE_IMPL(FloatImm, FloatImmImpl, Scalar); | |||
| MIND_API_BASE_IMPL(FP32Imm, FP32ImmImpl, FloatImm); | |||
| FP32Imm::FP32Imm(float value) : FloatImm(std::make_shared<FP32ImmImpl>(value)) {} | |||
| float FP32Imm::value() const { return ToRef<FP32ImmImpl>(impl_).value(); } | |||
| } // namespace mindspore::api | |||
| @@ -141,6 +141,7 @@ enum SubModuleId : int { | |||
| SM_HCCL_ADPT, // Hccl Adapter | |||
| SM_RUNTIME_FRAMEWORK, // Runtime framework | |||
| SM_GE, // GraphEngine | |||
| SM_API, // MindAPI | |||
| NUM_SUBMODUES // number of submodules | |||
| }; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -17,7 +17,6 @@ | |||
| #ifndef MINDSPORE_SHAPE_UTILS_INFO_H_ | |||
| #define MINDSPORE_SHAPE_UTILS_INFO_H_ | |||
| #include <vector> | |||
| using ShapeVector = std::vector<int64_t>; | |||
| #include "mindapi/base/shape_vector.h" | |||
| #endif // MINDSPORE_SHAPE_UTILS_INFO_H_ | |||
| @@ -16,6 +16,8 @@ set(API_IR_HEADER | |||
| ${CORE_DIR}/api/ir/func_graph.h | |||
| ${CORE_DIR}/api/ir/func_graph_manager.h | |||
| ) | |||
| file(GLOB MINDAPI_BASE_HEADER ${CORE_DIR}/mindapi/base/*.h) | |||
| file(GLOB MINDAPI_IR_HEADER ${CORE_DIR}/mindapi/ir/*.h) | |||
| set(BASE_HEADER | |||
| ${CORE_DIR}/base/base.h | |||
| ${CORE_DIR}/base/base_ref.h | |||
| @@ -15,6 +15,7 @@ OBJ:=$(SRC:.cc=.o) | |||
| CFLAGS := -Ofast -std=c++17 \ | |||
| -I . \ | |||
| -I ./msl/runtime \ | |||
| -I ./msl/runtime/include \ | |||
| -I ./msl/runtime/minddata \ | |||
| -I ./msl/tools/third_party/flatbuffers/include | |||
| @@ -15,6 +15,7 @@ OBJ:=$(SRC:.cc=.o) | |||
| CFLAGS := -Ofast -std=c++17 \ | |||
| -I . \ | |||
| -I ./msl/runtime \ | |||
| -I ./msl/runtime/include \ | |||
| -I ./msl/runtime/minddata \ | |||
| -I ./msl/tools/third_party/flatbuffers/include | |||
| @@ -19,6 +19,7 @@ INF_OBJ:=$(INF_SRC:.cc=.o) | |||
| CFLAGS := -Ofast -std=c++17 \ | |||
| -I . \ | |||
| -I ./msl/runtime \ | |||
| -I ./msl/runtime/include \ | |||
| -I ./msl/runtime/minddata \ | |||
| -I ./msl/tools/third_party/flatbuffers/include | |||
| @@ -289,6 +289,9 @@ if(APPLE) | |||
| set(MINDSPORE_LITE_PUB_HDRS_IR_HDRS | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/type_id.h | |||
| ) | |||
| set(MINDSPORE_LITE_PUB_HDRS_MINDAPI_HDRS | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../core/mindapi/base/type_id.h | |||
| ) | |||
| add_library(mindspore-lite_static STATIC | |||
| ${LITE_SRC} | |||
| ${MINDSPORE_LITE_PUB_HDRS} | |||
| @@ -423,6 +426,9 @@ if(DEFINED ARCHS) | |||
| FOREACH(HDR ${MINDSPORE_LITE_PUB_HDRS_IR_HDRS}) | |||
| SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/include/ir/dtype/) | |||
| ENDFOREACH() | |||
| FOREACH(HDR ${MINDSPORE_LITE_PUB_HDRS_MINDAPI_HDRS}) | |||
| SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/include/mindapi/base/) | |||
| ENDFOREACH() | |||
| target_link_libraries(mindspore-lite_static) | |||
| endif() | |||
| @@ -71,6 +71,7 @@ if(ENABLE_MINDDATA) | |||
| ./fl/*.cc | |||
| ./cxx_api/*.cc | |||
| ./tbe/*.cc | |||
| ./mindapi/*.cc | |||
| ) | |||
| if(NOT ENABLE_SECURITY) | |||
| file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| @@ -0,0 +1,394 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include <sstream> | |||
| #include <unordered_map> | |||
| #include "common/common_test.h" | |||
| #include "mindapi/base/logging.h" | |||
| #include "mindapi/ir/func_graph.h" | |||
| #include "mindapi/ir/tensor.h" | |||
| #include "mindapi/ir/utils.h" | |||
| namespace mindspore::api { | |||
| class TestMindApi : public UT::Common { | |||
| public: | |||
| TestMindApi() = default; | |||
| }; | |||
| /// Feature: MindAPI | |||
| /// Description: test basic 'is()' 'cast()' | |||
| /// Expectation: is/cast works correctly. | |||
| TEST_F(TestMindApi, test_base_isa_cast) { | |||
| auto value_node = MakeShared<ValueNode>(MakeValue(0)); | |||
| auto base = MakeShared<Base>(value_node->impl()); | |||
| ASSERT_TRUE(base->isa<Base>()); | |||
| ASSERT_TRUE(base->isa<AnfNode>()); | |||
| ASSERT_TRUE(base->isa<ValueNode>()); | |||
| ASSERT_FALSE(base->isa<AbstractBase>()); | |||
| auto anf_node = base->cast<AnfNodePtr>(); | |||
| ASSERT_TRUE(anf_node != nullptr); | |||
| ASSERT_TRUE(anf_node->impl() == value_node->impl()); | |||
| ASSERT_TRUE(base->cast<AbstractBasePtr>() == nullptr); | |||
| } | |||
| /// Feature: MindAPI | |||
| /// Description: test graph construction. | |||
| /// Expectation: graph is constructed as expected. | |||
| TEST_F(TestMindApi, test_graph_construction) { | |||
| // fg(x) { return myprim(x, 1); } | |||
| auto fg = FuncGraph::Create(); | |||
| auto x = fg->add_parameter(); | |||
| x->set_name("x"); | |||
| auto prim = MakeShared<Primitive>("myprim"); | |||
| auto prim_node = MakeShared<ValueNode>(prim); | |||
| auto value_node = MakeShared<ValueNode>(MakeValue(1)); | |||
| auto cnode = fg->NewCNode({prim_node, x, value_node}); | |||
| fg->set_output(cnode); | |||
| // Now we check the graph. | |||
| ASSERT_EQ(fg->parameters().size(), 1); | |||
| ASSERT_TRUE(fg->parameters()[0]->isa<Parameter>()); | |||
| ASSERT_EQ(fg->parameters()[0]->cast<ParameterPtr>()->name(), "x"); | |||
| auto ret_node = fg->get_return(); | |||
| ASSERT_TRUE(ret_node != nullptr); | |||
| auto output_node = fg->output(); | |||
| ASSERT_TRUE(output_node != nullptr); | |||
| ASSERT_TRUE(output_node->isa<CNode>()); | |||
| auto output_cnode = output_node->cast<CNodePtr>(); | |||
| ASSERT_EQ(output_cnode->inputs().size(), 3); | |||
| ASSERT_TRUE(output_cnode->input(0)->isa<ValueNode>()); | |||
| ASSERT_TRUE(output_cnode->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>()); | |||
| ASSERT_EQ(output_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()->name(), "myprim"); | |||
| ASSERT_TRUE(output_cnode->input(1)->isa<Parameter>()); | |||
| ASSERT_EQ(output_cnode->input(1)->cast<ParameterPtr>()->name(), "x"); | |||
| ASSERT_TRUE(output_cnode->input(2)->isa<ValueNode>()); | |||
| ASSERT_EQ(output_cnode->impl(), cnode->impl()); | |||
| } | |||
| /// Feature: MindAPI | |||
| /// Description: test value related functions. | |||
| /// Expectation: value related functions work as expected. | |||
| TEST_F(TestMindApi, test_values) { | |||
| int64_t one = 1; | |||
| auto s = MakeValue("hello"); | |||
| auto i = MakeValue(one); | |||
| auto i2 = MakeValue(2); | |||
| auto b = MakeValue(true); | |||
| auto f = MakeValue(3.14f); | |||
| auto seq = MakeValue(std::vector<int64_t>{3, 4, 5}); | |||
| auto seq_str = MakeValue(std::vector<std::string>({"this", "is", "mindspore", "api"})); | |||
| ASSERT_TRUE(s->isa<StringImm>()); | |||
| ASSERT_TRUE(i->isa<Int64Imm>()); | |||
| ASSERT_TRUE(i2->isa<Int64Imm>()); | |||
| ASSERT_TRUE(b->isa<BoolImm>()); | |||
| ASSERT_TRUE(f->isa<FP32Imm>()); | |||
| ASSERT_TRUE(seq->isa<ValueSequence>()); | |||
| ASSERT_TRUE(seq_str->isa<ValueSequence>()); | |||
| ASSERT_EQ(GetValue<std::string>(s), "hello"); | |||
| ASSERT_EQ(GetValue<int64_t>(i), one); | |||
| ASSERT_EQ(GetValue<int64_t>(i2), 2); | |||
| ASSERT_TRUE(GetValue<bool>(b)); | |||
| ASSERT_TRUE(std::abs(GetValue<float>(f) - 3.14f) < 0.00001f); | |||
| ASSERT_EQ(GetValue<std::string>(i), ""); | |||
| ASSERT_EQ(GetValue<int64_t>(s), 0); | |||
| ASSERT_FALSE(GetValue<bool>(s)); | |||
| ASSERT_EQ(GetValue<float>(s), 0.0f); | |||
| auto seq_ptr = seq->cast<ValueSequencePtr>(); | |||
| ASSERT_TRUE(seq_ptr != nullptr); | |||
| ASSERT_EQ(seq_ptr->size(), 3); | |||
| ASSERT_EQ(seq_ptr->value().size(), 3); | |||
| ASSERT_TRUE(seq_ptr->value()[0]->isa<Int64Imm>()); | |||
| ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[0]), 3); | |||
| ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[1]), 4); | |||
| ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[2]), 5); | |||
| auto seq_values = GetValue<std::vector<int64_t>>(seq); | |||
| ASSERT_EQ(seq_values.size(), 3); | |||
| ASSERT_EQ(seq_values[0], 3); | |||
| ASSERT_EQ(seq_values[1], 4); | |||
| ASSERT_EQ(seq_values[2], 5); | |||
| auto str_values = GetValue<std::vector<std::string>>(seq_str); | |||
| ASSERT_EQ(str_values.size(), 4); | |||
| ASSERT_EQ(str_values[0], "this"); | |||
| ASSERT_EQ(str_values[1], "is"); | |||
| ASSERT_EQ(str_values[2], "mindspore"); | |||
| ASSERT_EQ(str_values[3], "api"); | |||
| } | |||
| /// Feature: MindAPI | |||
| /// Description: test graph manager functions. | |||
| /// Expectation: graph manager functions work as expected. | |||
| TEST_F(TestMindApi, test_func_graph_manager) { | |||
| // fg(x, y) { return myprim(add(x, y), 1); } | |||
| auto fg = FuncGraph::Create(); | |||
| auto x = fg->add_parameter(); | |||
| x->set_name("x"); | |||
| auto y = fg->add_parameter(); | |||
| y->set_name("y"); | |||
| auto add = MakeShared<Primitive>("add"); | |||
| auto add_node = MakeShared<ValueNode>(add); | |||
| auto add_cnode = fg->NewCNode({add_node, x, y}); | |||
| auto prim = MakeShared<Primitive>("myprim"); | |||
| auto prim_node = MakeShared<ValueNode>(prim); | |||
| auto value_node = MakeShared<ValueNode>(MakeValue(1)); | |||
| auto cnode = fg->NewCNode({prim_node, add_cnode, value_node}); | |||
| fg->set_output(cnode); | |||
| auto mgr = FuncGraphManager::Manage(fg); | |||
| ASSERT_TRUE(mgr != nullptr); | |||
| ASSERT_TRUE(fg->manager() != nullptr); | |||
| ASSERT_EQ(fg->manager()->impl(), mgr->impl()); | |||
| ASSERT_EQ(fg->manager(), mgr); | |||
| ASSERT_EQ(cnode->input(1)->impl(), add_cnode->impl()); | |||
| mgr->Replace(add_cnode, x); | |||
| ASSERT_EQ(cnode->input(1)->impl(), x->impl()); | |||
| mgr->SetEdge(cnode, 1, y); | |||
| ASSERT_EQ(cnode->input(1)->impl(), y->impl()); | |||
| mgr->AddEdge(cnode, x); | |||
| ASSERT_EQ(cnode->size(), 4); | |||
| ASSERT_EQ(cnode->input(3)->impl(), x->impl()); | |||
| auto users = mgr->GetUsers(value_node); | |||
| ASSERT_EQ(users.size(), 1); | |||
| ASSERT_EQ(users[0].first, cnode); | |||
| ASSERT_EQ(users[0].second, 2); | |||
| } | |||
| /// Feature: MindAPI | |||
| /// Description: test value node utils. | |||
| /// Expectation: value node utils work as expected. | |||
| TEST_F(TestMindApi, test_value_node_utils) { | |||
| auto fg = FuncGraph::Create(); | |||
| auto fg_node = MakeShared<ValueNode>(fg); | |||
| auto prim = MakeShared<Primitive>("myprim"); | |||
| auto prim_node = MakeShared<ValueNode>(prim); | |||
| auto one = MakeShared<ValueNode>(MakeValue(1)); | |||
| auto cnode = fg->NewCNode({fg_node, prim_node, one}); | |||
| ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode) == nullptr); | |||
| auto fg1 = GetValueNode<FuncGraphPtr>(cnode->input(0)); | |||
| ASSERT_TRUE(fg1 != nullptr); | |||
| ASSERT_TRUE(fg1->isa<FuncGraph>()); | |||
| auto prim1 = GetValueNode<PrimitivePtr>(cnode->input(1)); | |||
| ASSERT_TRUE(prim1 != nullptr); | |||
| ASSERT_TRUE(prim1->isa<Primitive>()); | |||
| auto imm = GetValueNode<Int64ImmPtr>(cnode->input(2)); | |||
| ASSERT_TRUE(imm != nullptr); | |||
| ASSERT_TRUE(imm->isa<Int64Imm>()); | |||
| ASSERT_EQ(imm->cast<Int64ImmPtr>()->value(), 1); | |||
| auto value = GetValueNode(cnode->input(2)); | |||
| ASSERT_TRUE(value != nullptr); | |||
| ASSERT_EQ(GetValue<int64_t>(value), 1); | |||
| ASSERT_TRUE(GetValueNode<PrimitivePtr>(cnode->input(0)) == nullptr); | |||
| ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode->input(1)) == nullptr); | |||
| ASSERT_TRUE(GetValueNode<StringImmPtr>(cnode->input(2)) == nullptr); | |||
| // Test NewValueNode. | |||
| auto int_node = NewValueNode(1); | |||
| auto bool_node = NewValueNode(true); | |||
| auto float_node = NewValueNode(1.23f); | |||
| auto str_node = NewValueNode("hello"); | |||
| ASSERT_TRUE(int_node->value()->isa<Int64Imm>()); | |||
| ASSERT_EQ(int_node->value()->cast<Int64ImmPtr>()->value(), 1); | |||
| ASSERT_TRUE(bool_node->value()->isa<BoolImm>()); | |||
| ASSERT_TRUE(bool_node->value()->cast<BoolImmPtr>()->value()); | |||
| ASSERT_TRUE(float_node->value()->isa<FP32Imm>()); | |||
| ASSERT_TRUE(std::abs(float_node->value()->cast<FP32ImmPtr>()->value() - 1.23f) < 0.0000001f); | |||
| ASSERT_TRUE(str_node->value()->isa<StringImm>()); | |||
| ASSERT_EQ(str_node->value()->cast<StringImmPtr>()->value(), "hello"); | |||
| } | |||
| /// Feature: MindAPI | |||
| /// Description: test SharedPtr. | |||
| /// Expectation: SharedPtr work as expected. | |||
| TEST_F(TestMindApi, test_object_ptr) { | |||
| auto fg = FuncGraph::Create(); | |||
| auto fg_node = MakeShared<ValueNode>(fg); | |||
| auto prim = MakeShared<Primitive>("myprim"); | |||
| auto prim_node = MakeShared<ValueNode>(prim); | |||
| auto one = MakeShared<ValueNode>(MakeValue(1)); | |||
| auto cnode = fg->NewCNode({fg_node, prim_node, one}); | |||
| ASSERT_TRUE(fg != nullptr); | |||
| ASSERT_FALSE(!fg); | |||
| ASSERT_TRUE(fg ? true : false); | |||
| ASSERT_TRUE((*cnode).input(0) == fg_node); | |||
| ASSERT_TRUE(cnode->input(0) == fg_node); | |||
| ASSERT_TRUE(cnode.get()->input(0) == fg_node); | |||
| ASSERT_EQ(cnode->input(0), fg_node); | |||
| ASSERT_EQ(cnode->input(1), prim_node); | |||
| ASSERT_EQ(cnode->input(2), one); | |||
| ASSERT_TRUE(cnode->input(0) != fg); | |||
| AnfNodePtr p = fg_node; | |||
| ASSERT_TRUE(p == fg_node); | |||
| ASSERT_TRUE(p->isa<ValueNode>()); | |||
| ASSERT_TRUE(p->cast<ValueNodePtr>() != nullptr); | |||
| ASSERT_TRUE(p->cast<ValueNodePtr>() == fg_node); | |||
| p = cnode; | |||
| ASSERT_TRUE(p == cnode); | |||
| ASSERT_TRUE(p->isa<CNode>()); | |||
| ASSERT_TRUE(p->cast<CNodePtr>() != nullptr); | |||
| ASSERT_TRUE(p->cast<CNodePtr>() == cnode); | |||
| ASSERT_TRUE(p.get() == cnode.get()); | |||
| ASSERT_TRUE(p != nullptr); | |||
| ASSERT_FALSE(p == nullptr); | |||
| ASSERT_TRUE(p > nullptr); | |||
| ASSERT_FALSE(p < nullptr); | |||
| ASSERT_TRUE(p >= nullptr); | |||
| ASSERT_FALSE(p <= nullptr); | |||
| ASSERT_TRUE(nullptr != p); | |||
| ASSERT_FALSE(nullptr == p); | |||
| ASSERT_TRUE(nullptr < p); | |||
| ASSERT_FALSE(nullptr > p); | |||
| ASSERT_TRUE(nullptr <= p); | |||
| ASSERT_FALSE(nullptr >= p); | |||
| AnfNodePtr q = fg_node; | |||
| ASSERT_TRUE(p != q); | |||
| ASSERT_TRUE(p > q); | |||
| if (p.get()->impl() > q.get()->impl()) { | |||
| ASSERT_TRUE(p > q); | |||
| ASSERT_TRUE(p >= q); | |||
| ASSERT_TRUE(q < p); | |||
| ASSERT_TRUE(q <= p); | |||
| } else { | |||
| ASSERT_TRUE(p < q); | |||
| ASSERT_TRUE(p <= q); | |||
| ASSERT_TRUE(q > p); | |||
| ASSERT_TRUE(q >= p); | |||
| } | |||
| std::stringstream ss1; | |||
| std::stringstream ss2; | |||
| ss1 << p; | |||
| ss2 << cnode.get()->impl().get(); | |||
| ASSERT_EQ(ss1.str(), ss2.str()); | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> mymap; | |||
| mymap.emplace(p, q); | |||
| mymap.emplace(q, p); | |||
| ASSERT_TRUE(mymap.find(p) != mymap.end()); | |||
| ASSERT_TRUE(mymap.find(q) != mymap.end()); | |||
| ASSERT_TRUE(mymap[p] == q); | |||
| ASSERT_TRUE(mymap[q] == p); | |||
| } | |||
| /// Feature: MindAPI | |||
| /// Description: test Tensor API. | |||
| /// Expectation: Tensor API work as expected. | |||
| TEST_F(TestMindApi, test_tensor_api) { | |||
| ShapeVector shape{1, 2, 3}; | |||
| auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape); | |||
| ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32); | |||
| ASSERT_EQ(tensor->shape(), shape); | |||
| ASSERT_EQ(tensor->DataSize(), 6); | |||
| ASSERT_EQ(tensor->Size(), 24); | |||
| ShapeVector shape2{2, 3}; | |||
| tensor->set_data_type(kNumberTypeInt32); | |||
| tensor->set_shape(shape2); | |||
| ASSERT_EQ(tensor->data_type(), kNumberTypeInt32); | |||
| ASSERT_EQ(tensor->shape(), shape2); | |||
| } | |||
| /// Feature: MindAPI | |||
| /// Description: test utils API. | |||
| /// Expectation: Tensor API work as expected. | |||
| TEST_F(TestMindApi, test_api_utils) { | |||
| // Test utils::isa, utils::cast. | |||
| auto anf_node = NewValueNode("hello"); | |||
| ASSERT_TRUE(utils::isa<AnfNode>(anf_node)); | |||
| ASSERT_FALSE(utils::isa<AbstractBase>(anf_node)); | |||
| ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) != nullptr); | |||
| ASSERT_TRUE(utils::cast<AbstractBasePtr>(anf_node) == nullptr); | |||
| anf_node = nullptr; | |||
| ASSERT_FALSE(utils::isa<AnfNode>(anf_node)); | |||
| ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) == nullptr); | |||
| // Test clone graph. | |||
| auto fg = FuncGraph::Create(); | |||
| auto x = fg->add_parameter(); | |||
| x->set_name("x"); | |||
| auto y = fg->add_parameter(); | |||
| y->set_name("y"); | |||
| auto add = MakeShared<Primitive>("add"); | |||
| auto add_node = MakeShared<ValueNode>(add); | |||
| auto add_cnode = fg->NewCNode({add_node, x, y}); | |||
| auto prim = MakeShared<Primitive>("myprim"); | |||
| auto prim_node = MakeShared<ValueNode>(prim); | |||
| auto value_node = MakeShared<ValueNode>(MakeValue(1)); | |||
| auto cnode = fg->NewCNode({prim_node, add_cnode, value_node}); | |||
| fg->set_output(cnode); | |||
| auto cloned_fg = utils::CloneGraph(fg); | |||
| ASSERT_TRUE(cloned_fg != nullptr); | |||
| ASSERT_EQ(cloned_fg->parameters().size(), 2); | |||
| auto new_output = cloned_fg->output(); | |||
| ASSERT_TRUE(new_output != nullptr); | |||
| ASSERT_TRUE(new_output->isa<CNode>()); | |||
| ASSERT_EQ(new_output->cast<CNodePtr>()->size(), cnode->size()); | |||
| ASSERT_TRUE(new_output != cnode); | |||
| ASSERT_TRUE(new_output->cast<CNodePtr>() != cnode); | |||
| // Test get pad mode. | |||
| auto pm_lower = MakeValue("pad"); | |||
| auto pm_upper = MakeValue("PAD"); | |||
| ASSERT_EQ(utils::GetPadMode(pm_lower), 0); | |||
| ASSERT_EQ(utils::GetPadMode(pm_lower, false), 0); | |||
| ASSERT_EQ(utils::GetPadMode(pm_upper, true), 0); | |||
| } | |||
| /// Feature: MindAPI | |||
| /// Description: test logging API. | |||
| /// Expectation: logging work as expected. | |||
| TEST_F(TestMindApi, test_api_logging) { | |||
| MS_LOG(DEBUG) << "hello debug"; | |||
| MS_LOG(INFO) << "hello info"; | |||
| MS_LOG(WARNING) << "hello warning"; | |||
| MS_LOG(ERROR) << "hello error"; | |||
| try { | |||
| MS_LOG(EXCEPTION) << "hello exception"; | |||
| ASSERT_TRUE(false); | |||
| } catch (...) { | |||
| } | |||
| ASSERT_TRUE(true); | |||
| } | |||
| } // namespace mindspore::api | |||