Browse Source

Provide new mindspore core API classes

1. namespace is mindspore::api;
2. API header files located in mindspore/core/mindapi;
3. We use pimpl pattern to provide a wrapper layer for api;
4. Check mindapi_test.cc for usage examples.
tags/v1.6.0
He Wei 4 years ago
parent
commit
d2bb6303b7
46 changed files with 3087 additions and 99 deletions
  1. +12
    -0
      cmake/package_lite.cmake
  2. +1
    -0
      mindspore/core/CMakeLists.txt
  3. +9
    -3
      mindspore/core/api/ir/func_graph.h
  4. +9
    -4
      mindspore/core/api/ir/func_graph_manager.h
  5. +1
    -0
      mindspore/core/gvar/log_adapter_common.cc
  6. +3
    -84
      mindspore/core/ir/dtype/type_id.h
  7. +2
    -2
      mindspore/core/ir/func_graph.h
  8. +4
    -3
      mindspore/core/ir/manager.h
  9. +89
    -0
      mindspore/core/mindapi/base/base.h
  10. +104
    -0
      mindspore/core/mindapi/base/logging.h
  11. +30
    -0
      mindspore/core/mindapi/base/macros.h
  12. +25
    -0
      mindspore/core/mindapi/base/shape_vector.h
  13. +180
    -0
      mindspore/core/mindapi/base/shared_ptr.h
  14. +104
    -0
      mindspore/core/mindapi/base/type_id.h
  15. +42
    -0
      mindspore/core/mindapi/base/type_traits.h
  16. +101
    -0
      mindspore/core/mindapi/ir/abstract.h
  17. +232
    -0
      mindspore/core/mindapi/ir/anf.h
  18. +50
    -0
      mindspore/core/mindapi/ir/common.h
  19. +193
    -0
      mindspore/core/mindapi/ir/func_graph.h
  20. +79
    -0
      mindspore/core/mindapi/ir/primitive.h
  21. +36
    -0
      mindspore/core/mindapi/ir/shape.h
  22. +93
    -0
      mindspore/core/mindapi/ir/tensor.h
  23. +56
    -0
      mindspore/core/mindapi/ir/type.h
  24. +70
    -0
      mindspore/core/mindapi/ir/utils.h
  25. +270
    -0
      mindspore/core/mindapi/ir/value.h
  26. +84
    -0
      mindspore/core/mindapi/src/abstract.cc
  27. +135
    -0
      mindspore/core/mindapi/src/anf.cc
  28. +28
    -0
      mindspore/core/mindapi/src/base.cc
  29. +170
    -0
      mindspore/core/mindapi/src/func_graph.cc
  30. +76
    -0
      mindspore/core/mindapi/src/helper.h
  31. +75
    -0
      mindspore/core/mindapi/src/logging.cc
  32. +65
    -0
      mindspore/core/mindapi/src/primitive.cc
  33. +27
    -0
      mindspore/core/mindapi/src/shape.cc
  34. +47
    -0
      mindspore/core/mindapi/src/tensor.cc
  35. +39
    -0
      mindspore/core/mindapi/src/type.cc
  36. +43
    -0
      mindspore/core/mindapi/src/utils.cc
  37. +94
    -0
      mindspore/core/mindapi/src/value.cc
  38. +1
    -0
      mindspore/core/utils/log_adapter.h
  39. +2
    -3
      mindspore/core/utils/shape_utils.h
  40. +2
    -0
      mindspore/lite/cmake/file_list.cmake
  41. +1
    -0
      mindspore/lite/examples/train_lenet/Makefile
  42. +1
    -0
      mindspore/lite/examples/transfer_learning/Makefile
  43. +1
    -0
      mindspore/lite/examples/unified_api/Makefile
  44. +6
    -0
      mindspore/lite/src/CMakeLists.txt
  45. +1
    -0
      tests/ut/cpp/CMakeLists.txt
  46. +394
    -0
      tests/ut/cpp/mindapi/mindapi_test.cc

+ 12
- 0
cmake/package_lite.cmake View File

@@ -221,6 +221,8 @@ if(PLATFORM_ARM64)
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
@@ -298,6 +300,8 @@ elseif(PLATFORM_ARM32)
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
@@ -365,6 +369,8 @@ elseif(WIN32)
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
@@ -409,6 +415,8 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/core/mindapi/base/type_id.h DESTINATION ${RUNTIME_INC_DIR}/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
@@ -430,6 +438,10 @@ else()
PATTERN "train*" EXCLUDE PATTERN "delegate.h" EXCLUDE PATTERN "lite_session.h" EXCLUDE)
install(FILES ${API_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/api
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${MINDAPI_BASE_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/mindapi/base
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${MINDAPI_IR_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/mindapi/ir
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${ABSTRACT_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${API_IR_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/api/ir


+ 1
- 0
mindspore/core/CMakeLists.txt View File

@@ -23,6 +23,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"ir/*.cc"
"utils/*.cc"
"load_mindir/*.cc"
"mindapi/src/*.cc"
)

if(ENABLE_SECURITY)


+ 9
- 3
mindspore/core/api/ir/func_graph.h View File

@@ -24,8 +24,7 @@
#include "utils/visible.h"
#include "api/ir/func_graph_manager.h"

namespace mindspore::api {

namespace mindspore::deprecated::api {
/// \brief FuncGraph defines interface for a function graph.
class MS_CORE_API FuncGraph {
public:
@@ -147,5 +146,12 @@ class MS_CORE_API FuncGraph {
/// \return The function graph if the input is value node that holds the graph, nullptr otherwise.
static FuncGraphPtr GetFuncGraphFromAnfNode(const AnfNodePtr &input);
};
} // namespace mindspore::api

#ifndef USE_DEPRECATED_API
#define USE_DEPRECATED_API
namespace mindspore {
namespace api = deprecated::api;
}
#endif
} // namespace mindspore::deprecated::api
#endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_H_

+ 9
- 4
mindspore/core/api/ir/func_graph_manager.h View File

@@ -26,8 +26,7 @@
#include "utils/hashing.h"
#include "ir/anf.h"

namespace mindspore::api {

namespace mindspore::deprecated::api {
class FuncGraph;
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

@@ -80,7 +79,13 @@ class MS_CORE_API FuncGraphManager {
/// \return The manager that manages the given function graph.
static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);
};

} // namespace mindspore::api
} // namespace mindspore::deprecated::api

#ifndef USE_DEPRECATED_API
#define USE_DEPRECATED_API
namespace mindspore {
namespace api = deprecated::api;
}
#endif

#endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_MANAGER_H_

+ 1
- 0
mindspore/core/gvar/log_adapter_common.cc View File

@@ -52,6 +52,7 @@ static const std::vector<std::string> sub_module_names = {
"HCCL_ADPT", // SM_HCCL_ADPT
"RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK
"GE", // SM_GE
"API", // SM_API
};

const std::string GetSubModuleName(SubModuleId module_id) { return sub_module_names[(module_id % NUM_SUBMODUES)]; }


+ 3
- 84
mindspore/core/ir/dtype/type_id.h View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,87 +19,6 @@
#ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
#define MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_

namespace mindspore {
//
// Supported meta type
//
enum TypeId : int {
kTypeUnknown = 0,
kMetaTypeBegin = kTypeUnknown,
kMetaTypeType, // Type
kMetaTypeAnything,
kMetaTypeObject,
kMetaTypeTypeType, // TypeType
kMetaTypeProblem,
kMetaTypeExternal,
kMetaTypeNone,
kMetaTypeNull,
kMetaTypeEllipsis,
kMetaTypeEnd,
//
// Object types
//
kObjectTypeBegin = kMetaTypeEnd,
kObjectTypeNumber,
kObjectTypeString,
kObjectTypeList,
kObjectTypeTuple,
kObjectTypeSlice,
kObjectTypeKeyword,
kObjectTypeTensorType,
kObjectTypeRowTensorType,
kObjectTypeSparseTensorType,
kObjectTypeUndeterminedType,
kObjectTypeClass,
kObjectTypeDictionary,
kObjectTypeFunction,
kObjectTypeJTagged,
kObjectTypeSymbolicKeyType,
kObjectTypeEnvType,
kObjectTypeRefKey,
kObjectTypeRef,
kObjectTypeEnd,
//
// Number Types
//
kNumberTypeBegin = kObjectTypeEnd,
kNumberTypeBool,
kNumberTypeInt,
kNumberTypeInt8,
kNumberTypeInt16,
kNumberTypeInt32,
kNumberTypeInt64,
kNumberTypeUInt,
kNumberTypeUInt8,
kNumberTypeUInt16,
kNumberTypeUInt32,
kNumberTypeUInt64,
kNumberTypeFloat,
kNumberTypeFloat16,
kNumberTypeFloat32,
kNumberTypeFloat64,
kNumberTypeComplex,
kNumberTypeComplex64,
kNumberTypeComplex128,
kNumberTypeInt4,
kNumberTypeGLUInt,
kNumberTypeEnd,
//
// Monad Types
//
kMonadTypeBegin = kNumberTypeEnd,
kObjectTypeMonad,
kObjectTypeUMonad,
kObjectTypeIOMonad,
kMonadTypeEnd,
//
// Sparse Types
//
// Sparse types is placed at the end of enum,
// in order to keep fit with the type of existing model on the lite side.
kSparseTypeBegin = kMonadTypeEnd,
kObjectTypeCSRTensorType,
kSparseTypeEnd
};
} // namespace mindspore
#include "mindapi/base/type_id.h"

#endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_

+ 2
- 2
mindspore/core/ir/func_graph.h View File

@@ -153,7 +153,7 @@ class FuncGraphBase : public Value {
MS_DECLARE_PARENT(FuncGraphBase, Value);
};

class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
public:
using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;

@@ -265,7 +265,7 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
FuncGraphManagerPtr manager() const { return manager_.lock(); }
void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }

api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
deprecated::api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }

std::string ToString() const override;
GraphDebugInfoPtr debug_info();


+ 4
- 3
mindspore/core/ir/manager.h View File

@@ -55,9 +55,9 @@ class FuncGraphTransaction;
class FuncGraphManager;
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;

using AnfNodeIndexSet = api::AnfNodeIndexSet;
using AnfNodeIndexSet = deprecated::api::AnfNodeIndexSet;
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
using NodeUsersMap = api::NodeUsersMap;
using NodeUsersMap = deprecated::api::NodeUsersMap;
using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;

@@ -277,7 +277,8 @@ class FuncGraphJTotalComputer final : public DepComputer {
bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
};

class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, public api::FuncGraphManager {
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>,
public deprecated::api::FuncGraphManager {
public:
explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
~FuncGraphManager() {


+ 89
- 0
mindspore/core/mindapi/base/base.h View File

@@ -0,0 +1,89 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_BASE_BASE_H_
#define MINDSPORE_CORE_MINDAPI_BASE_BASE_H_

#include <cstdint>
#include <string>
#include <memory>
#include "mindapi/base/macros.h"
#include "mindapi/base/type_traits.h"
#include "mindapi/base/shared_ptr.h"

namespace mindspore {
class Base;
}

namespace mindspore::api {
/// \brief Base is the base class of many api classes, which provides basic interfaces.
class MIND_API Base {
public:
/// \brief Create an instance from the given implementation object.
///
/// \param[in] impl The shared_ptr to the implementation object.
explicit Base(const std::shared_ptr<mindspore::Base> &impl);

/// \brief Destructor of Base.
virtual ~Base() = default;

/// \brief Get the id of this class.
///
/// \return The id of this class.
static uint32_t ClassId();

/// \brief Get the shared_ptr to the underly implementation object.
///
/// \return The shared_ptr to the underly implementation object.
const std::shared_ptr<mindspore::Base> &impl() const { return impl_; }

/// \brief Get the string representation of this object.
///
/// \return The string representation.
std::string ToString() const;

/// \brief Check whether this object is an instance of the given class.
///
/// \return True if this object is an instance of the given class, false otherwise.
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>, T>>
inline bool isa() const {
return IsFromClassId(T::ClassId());
}

/// \brief Cast this object to a pointer with the given pointer class.
///
/// \return A non-null pointer if cast success, nullptr otherwise.
template <typename T, typename U = typename std::enable_if_t<is_wrapper_ptr<T>::value, typename T::element_type>>
inline T cast() {
if (isa<U>()) {
return MakeShared<U>(impl_);
}
return nullptr;
}

protected:
bool IsFromClassId(uint32_t class_id) const;
const std::shared_ptr<mindspore::Base> impl_;
};

#define MIND_API_BASE_MEMBER(current_class) \
explicit current_class(const std::shared_ptr<mindspore::Base> &impl); \
~current_class() override = default; \
static uint32_t ClassId()

using BasePtr = SharedPtr<Base>;
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_BASE_BASE_H_

+ 104
- 0
mindspore/core/mindapi/base/logging.h View File

@@ -0,0 +1,104 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_
#define MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_

#include <cstdint>
#include <memory>
#include <sstream>
#include <utility>
#include "mindapi/base/macros.h"

namespace mindspore::api {
enum class LogLevel : uint8_t { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION };

class LogWriterImpl;

/// \brief LogStream represents a stream to write log messages.
/// This class is not expected for directly use, use MS_LOG instead.
class LogStream {
public:
/// \brief Write log message to this LogStream.
///
/// \param[in] value The object to be written.
template <typename T>
LogStream &operator<<(T &&value) noexcept {
(void)stream_.operator<<(std::forward<T>(value));
return *this;
}

private:
friend class LogWriterImpl;
std::stringstream stream_;
};

/// \brief LogWriter defines interface for log message output.
/// This class is not expected for directly use, use MS_LOG instead.
class MIND_API LogWriter {
public:
/// \brief Create a LogWriter with the given log level, file name, line number and function name.
///
/// \param[in] level The log level.
/// \param[in] file The file name.
/// \param[in] line The line number.
/// \param[in] func The function name.
LogWriter(LogLevel level, const char *file, int line, const char *func);

/// \brief Destructor for LogWriter.
~LogWriter();

/// \brief Output log message from the input log stream.
///
/// \param[in] stream The input log stream.
void operator<(const LogStream &stream) const noexcept;

/// \brief Output log message from the input log stream and then throw exception.
///
/// \param[in] stream The input log stream.
void operator^(const LogStream &stream) const __attribute__((noreturn));

/// \brief Check whether the given log level is enabled or not.
///
/// \return True if the log level is enabled, false otherwise.
static bool IsEnabled(LogLevel level);

private:
std::unique_ptr<LogWriterImpl> impl_;
};

#define MIND_LOG_STREAM mindspore::api::LogStream()
#define MIND_LOG_WRITER mindspore::api::LogWriter
#define MIND_LOG_LEVEL(L) mindspore::api::LogLevel::L

#define MIND_LOG_THROW(L) MIND_LOG_WRITER(MIND_LOG_LEVEL(L), __FILE__, __LINE__, __FUNCTION__) ^ MIND_LOG_STREAM
#define MIND_LOG_WRITE(L) MIND_LOG_WRITER(MIND_LOG_LEVEL(L), __FILE__, __LINE__, __FUNCTION__) < MIND_LOG_STREAM
#define MIND_LOG_IF(L) \
if (MIND_LOG_WRITER::IsEnabled(MIND_LOG_LEVEL(L))) MIND_LOG_WRITE(L)

#define MIND_LOG_DEBUG MIND_LOG_IF(DEBUG)
#define MIND_LOG_INFO MIND_LOG_IF(INFO)
#define MIND_LOG_WARNING MIND_LOG_IF(WARNING)
#define MIND_LOG_ERROR MIND_LOG_IF(ERROR)
#define MIND_LOG_EXCEPTION MIND_LOG_THROW(EXCEPTION)
#define MIND_LOG(level) MIND_LOG_##level

#if !defined(MIND_LOG_NO_MS_LOG) && !defined(MS_LOG)
#define MS_LOG(level) MIND_LOG_##level
#endif
} // namespace mindspore::api

#endif // MINDSPORE_CORE_MINDAPI_BASE_LOGGING_H_

+ 30
- 0
mindspore/core/mindapi/base/macros.h View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_
#define MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_

#if (defined(_WIN32) || defined(__WIN32__) || defined(WIN32) || defined(__CYGWIN__))
#ifdef BUILDING_DLL
#define MIND_API __declspec(dllexport)
#else
#define MIND_API __declspec(dllimport)
#endif
#else
#define MIND_API __attribute__((visibility("default")))
#endif

#endif // MINDSPORE_CORE_MINDAPI_BASE_MACROS_H_

+ 25
- 0
mindspore/core/mindapi/base/shape_vector.h View File

@@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_
#define MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_

#include <cstdint>
#include <vector>

using ShapeVector = std::vector<int64_t>;

#endif // MINDSPORE_CORE_MINDAPI_BASE_SHAPE_VECTOR_H_

+ 180
- 0
mindspore/core/mindapi/base/shared_ptr.h View File

@@ -0,0 +1,180 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_
#define MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_

#include <cstdint>
#include <memory>
#include <utility>
#include <ostream>
#include <functional>

namespace mindspore::api {
/// \brief SharedPtr wraps a std::shared_ptr and provides wrapper functions according the underlying implementation.
template <typename T>
class SharedPtr {
public:
using element_type = T;
constexpr SharedPtr() noexcept = default;
constexpr SharedPtr(std::nullptr_t) noexcept : SharedPtr() {} // NOLINT
template <typename U>
explicit SharedPtr(std::shared_ptr<U> &&ptr) : ptr_(std::move(ptr)) {}
template <typename U>
SharedPtr(const SharedPtr<U> &other) : ptr_(other.ptr_) {}
template <typename U>
SharedPtr(SharedPtr<U> &&other) : ptr_(std::move(other.ptr_)) {}
template <typename U>
SharedPtr &operator=(const SharedPtr<U> &other) {
ptr_ = other.ptr_;
return *this;
}
template <typename U>
SharedPtr &operator=(SharedPtr<U> &&other) {
ptr_ = std::move(other.ptr_);
return *this;
}
~SharedPtr() = default;

std::uintptr_t addr() const { return (ptr_ == nullptr) ? 0 : reinterpret_cast<std::uintptr_t>(ptr_->impl().get()); }
element_type &operator*() const noexcept { return *ptr_; }
element_type *operator->() const noexcept { return ptr_.get(); }
element_type *get() const noexcept { return ptr_.get(); }
explicit operator bool() const { return addr() != 0; }

private:
template <typename U>
friend class SharedPtr;
std::shared_ptr<element_type> ptr_;
};

template <typename T, typename U>
inline bool operator==(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() == b.addr();
}

template <typename T>
inline bool operator==(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() == 0;
}

template <typename T>
inline bool operator==(std::nullptr_t, const SharedPtr<T> &a) noexcept {
return a.addr() == 0;
}

template <typename T, typename U>
inline bool operator!=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() != b.addr();
}

template <typename T>
inline bool operator!=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() != 0;
}

template <typename T>
inline bool operator!=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
return a.addr() != 0;
}

template <typename T, typename U>
inline bool operator<(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() < b.addr();
}

template <typename T>
inline bool operator<(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() < 0;
}

template <typename T>
inline bool operator<(std::nullptr_t, const SharedPtr<T> &a) noexcept {
// 'nullptr < ptr' is false only when ptr is nullptr.
return a.addr() != 0;
}

template <typename T, typename U>
inline bool operator>(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() > b.addr();
}

template <typename T>
inline bool operator>(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() > 0;
}

template <typename T>
inline bool operator>(std::nullptr_t, const SharedPtr<T> &a) noexcept {
// 'nullptr > ptr' is always false.
return false;
}

template <typename T, typename U>
inline bool operator<=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() <= b.addr();
}

template <typename T>
inline bool operator<=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() <= 0;
}

template <typename T>
inline bool operator<=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
// 'nullptr <= ptr' is always true.
return true;
}

template <typename T, typename U>
inline bool operator>=(const SharedPtr<T> &a, const SharedPtr<U> &b) noexcept {
return a.addr() >= b.addr();
}

template <typename T>
inline bool operator>=(const SharedPtr<T> &a, std::nullptr_t) noexcept {
return a.addr() >= 0;
}

template <typename T>
inline bool operator>=(std::nullptr_t, const SharedPtr<T> &a) noexcept {
// 'nullptr >= ptr' is true only when ptr is nullptr.
return a.addr() == 0;
}

template <typename T, typename U, typename V>
inline std::basic_ostream<U, V> &operator<<(std::basic_ostream<U, V> &os, const SharedPtr<T> &a) {
return (os << reinterpret_cast<void *>(a.addr()));
}

/// \brief Constructs an object of type T and wraps it in a SharedPtr.
///
/// \param[in] args The parameter list for the constructor of T.
template <typename T, typename... Args>
inline SharedPtr<T> MakeShared(Args &&... args) {
auto ptr = std::make_shared<T>(std::forward<Args>(args)...);
return SharedPtr<T>(std::move(ptr));
}
} // namespace mindspore::api

namespace std {
template <typename T>
struct hash<mindspore::api::SharedPtr<T>> {
size_t operator()(const mindspore::api::SharedPtr<T> &ptr) const noexcept { return static_cast<size_t>(ptr.addr()); }
};
} // namespace std

#endif // MINDSPORE_CORE_MINDAPI_BASE_SHARED_PTR_H_

+ 104
- 0
mindspore/core/mindapi/base/type_id.h View File

@@ -0,0 +1,104 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
#define MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_

namespace mindspore {
/// \brief TypeId defines data type identifiers.
enum TypeId : int {
kTypeUnknown = 0,
//
// Meta types.
//
kMetaTypeBegin = kTypeUnknown,
kMetaTypeType, // Type
kMetaTypeAnything,
kMetaTypeObject,
kMetaTypeTypeType, // TypeType
kMetaTypeProblem,
kMetaTypeExternal,
kMetaTypeNone,
kMetaTypeNull,
kMetaTypeEllipsis,
kMetaTypeEnd,
//
// Object types
//
kObjectTypeBegin = kMetaTypeEnd,
kObjectTypeNumber,
kObjectTypeString,
kObjectTypeList,
kObjectTypeTuple,
kObjectTypeSlice,
kObjectTypeKeyword,
kObjectTypeTensorType,
kObjectTypeRowTensorType,
kObjectTypeSparseTensorType,
kObjectTypeUndeterminedType,
kObjectTypeClass,
kObjectTypeDictionary,
kObjectTypeFunction,
kObjectTypeJTagged,
kObjectTypeSymbolicKeyType,
kObjectTypeEnvType,
kObjectTypeRefKey,
kObjectTypeRef,
kObjectTypeEnd,
//
// Number Types
//
kNumberTypeBegin = kObjectTypeEnd,
kNumberTypeBool,
kNumberTypeInt,
kNumberTypeInt8,
kNumberTypeInt16,
kNumberTypeInt32,
kNumberTypeInt64,
kNumberTypeUInt,
kNumberTypeUInt8,
kNumberTypeUInt16,
kNumberTypeUInt32,
kNumberTypeUInt64,
kNumberTypeFloat,
kNumberTypeFloat16,
kNumberTypeFloat32,
kNumberTypeFloat64,
kNumberTypeComplex,
kNumberTypeComplex64,
kNumberTypeComplex128,
kNumberTypeInt4,
kNumberTypeGLUInt,
kNumberTypeEnd,
//
// Monad Types
//
kMonadTypeBegin = kNumberTypeEnd,
kObjectTypeMonad,
kObjectTypeUMonad,
kObjectTypeIOMonad,
kMonadTypeEnd,
//
// Sparse Types
//
// Sparse types is placed at the end of enum,
// in order to keep fit with the type of existing model on the lite side.
kSparseTypeBegin = kMonadTypeEnd,
kObjectTypeCSRTensorType,
kSparseTypeEnd
};
} // namespace mindspore
#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_

+ 42
- 0
mindspore/core/mindapi/base/type_traits.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_
#define MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_

#include <vector>
#include <memory>
#include <type_traits>
#include "mindapi/base/shared_ptr.h"

namespace mindspore::api {
template <typename T>
struct is_wrapper_ptr : public std::false_type {};
template <typename T>
struct is_wrapper_ptr<SharedPtr<T>> : public std::true_type {};

template <typename T>
struct is_shared_ptr : public std::false_type {};
template <typename T>
struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {};

template <typename T>
struct is_vector : public std::false_type {};
template <typename T, typename A>
struct is_vector<std::vector<T, A>> : public std::true_type {};
} // namespace mindspore::api

#endif // MINDSPORE_CORE_MINDAPI_BASE_TYPE_TRAITS_H_

+ 101
- 0
mindspore/core/mindapi/ir/abstract.h View File

@@ -0,0 +1,101 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
#define MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_

#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/shape.h"
#include "mindapi/ir/type.h"
#include "mindapi/ir/value.h"

namespace mindspore::api {
/// \brief AbstractBase defines base interfaces for abstract of an anf node.
class MIND_API AbstractBase : public Base {
public:
MIND_API_BASE_MEMBER(AbstractBase);

/// \brief Clone an abstract from this abstract.
///
/// \return A pointer to the cloned abstract.
AbstractBasePtr Clone() const;

/// \brief Get the abstract type.
///
/// \return A pointer to the Type.
TypePtr type() const;

/// \brief Get the abstract value.
///
/// \return A pointer to the Value.
ValuePtr value() const;

/// \brief Set the type for this abstract.
///
/// \param[in] type The type to be set.
void set_type(const TypePtr &type);

/// \brief Set the value for this abstract.
///
/// \param[in] value The value to be set.
void set_value(const ValuePtr &value);
};

/// \brief AbstractTensor describes a tensor's type, shape and value.
class MIND_API AbstractTensor : public AbstractBase {
public:
MIND_API_BASE_MEMBER(AbstractTensor);

/// \brief Create AbstractTensor from the given type and shape.
///
/// \param[in] type The data type id of the tensor.
/// \param[in] shape The shape of the tensor.
AbstractTensor(TypeId type, const ShapeVector &shape);

/// \brief Get the element abstract.
///
/// \return A pointer to the element abstract.
AbstractBasePtr element() const;

/// \brief Get the shape of the abstract.
///
/// \return A pointer to the shape.
ShapePtr shape() const;
};

using AbstractTensorPtr = SharedPtr<AbstractTensor>;

/// \brief AbstractSequence describes the abstract for a tuple or list.
class MIND_API AbstractSequence : public AbstractBase {
public:
MIND_API_BASE_MEMBER(AbstractSequence);

/// \brief Get element abstracts.
///
/// \return A vector of element abstracts.
AbstractBasePtrList elements() const;
};

using AbstractSequencePtr = SharedPtr<AbstractSequence>;

/// \brief AbstractTuple describes the abstract for a tuple.
class MIND_API AbstractTuple : public AbstractSequence {
public:
MIND_API_BASE_MEMBER(AbstractTuple);
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_

+ 232
- 0
mindspore/core/mindapi/ir/anf.h View File

@@ -0,0 +1,232 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_ANF_H_
#define MINDSPORE_CORE_MINDAPI_IR_ANF_H_

#include <vector>
#include <string>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/abstract.h"
#include "mindapi/ir/primitive.h"
#include "mindapi/ir/value.h"

namespace mindspore::api {
/// \brief AnfNode is the basic class of the IR graph node.
class MIND_API AnfNode : public Base {
public:
MIND_API_BASE_MEMBER(AnfNode);

/// \brief Obtain detailed information about scope namespace.
///
/// \return Detailed information about scope namespace.
std::string fullname_with_scope() const;

/// \brief Obtain the inferred abstract value of this AnfNode.
///
/// \return The inferred abstract value.
AbstractBasePtr abstract() const;

/// \brief Set the abstract value of this AnfNode.
///
/// \param[in] abs New abstract value.
void set_abstract(const AbstractBasePtr &abs);
};

/// \brief CNode represents a compute node with a set of input nodes.
class MIND_API CNode : public AnfNode {
public:
MIND_API_BASE_MEMBER(CNode);

/// \brief Get the number of inputs.
///
/// \return The number of inputs in this CNode.
size_t size() const;

/// \brief Get the input node of the given index.
///
/// \param[in] i The given index.
///
/// \return The input node of the given index.
AnfNodePtr input(size_t i) const;

/// \brief Get the input nodes.
///
/// \return The input nodes of this CNode.
std::vector<AnfNodePtr> inputs() const;

/// \brief Set the input nodes for this CNode.
///
/// \param[in] inputs Input nodes.
void set_inputs(const std::vector<AnfNodePtr> &inputs);

/// \brief Add an input node to this CNode.
///
/// \param[in] input the input node to be added.
void add_input(const AnfNodePtr &input);

/// \brief Set fullname_with_scope for this CNode.
///
/// \param[in] full_name The fullname_with_scope.
void set_fullname_with_scope(const std::string &full_name);

/// \brief Add a new attribute to this CNode.
///
/// \param[in] name The name of the new attribute.
/// \param[in] attr The value of the new attribute.
void AddAttr(const std::string &name, const ValuePtr &attr);

/// \brief Erase the attribute with the given name.
///
/// \param[in] name The name of attribute.
void EraseAttr(const std::string &name);

/// \brief Get the attribute with the given name.
///
/// \param[in] name The name of attribute.
/// \return Attribute.
ValuePtr GetAttr(const std::string &name) const;
};

using CNodePtr = SharedPtr<CNode>;

/// \brief Parameter represents the parameter inputs of a function.
class MIND_API Parameter : public AnfNode {
public:
MIND_API_BASE_MEMBER(Parameter);

/// \brief Get the name of this Parameter.
///
/// \return The name.
std::string name() const;

/// \brief Set the name of this Parameter.
///
/// \param[in] The name.
void set_name(const std::string &name);

/// \brief Check if there is a default parameter.
///
/// \return True if this Parameter has a default parameter, otherwise false.
bool has_default() const;

/// \brief Set the default parameter.
///
/// \param[in] param The default parameter.
void set_default_param(const ValuePtr &param);

/// \brief Get the default parameter.
///
/// \return The default parameter.
ValuePtr default_param() const;
};

using ParameterPtr = SharedPtr<Parameter>;

/// \brief ValueNode is a graph node that hold a value.
class MIND_API ValueNode : public AnfNode {
public:
MIND_API_BASE_MEMBER(ValueNode);

/// \brief Create ValueNode with the given value.
///
/// \param[in] value The value of this ValueNode.
explicit ValueNode(const ValuePtr &value);

/// \brief Get the value of this ValueNode.
///
/// \return The value.
ValuePtr value() const;
};

using ValueNodePtr = SharedPtr<ValueNode>;

// === ANF utility functions === //

/// \brief Create a ValueNode with the given value.
///
/// \param[in] value The given value.
///
/// \return The created ValueNode.
inline ValueNodePtr NewValueNode(const ValuePtr &value) { return MakeShared<ValueNode>(value); }

/// \brief Create a ValueNode with the given primitive type value.
///
/// \param[in] value The given primitive type value.
///
/// \return The created ValueNode.
template <typename T>
inline ValueNodePtr NewValueNode(T value) {
return NewValueNode(MakeValue(value));
}

/// \brief Get the value from a node if it is a ValueNode.
///
/// \param[in] node The node which may hold a value.
///
/// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set.
inline ValuePtr GetValueNode(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return nullptr;
}
return value_node->value();
}

/// \brief Get the value with the given type from a node if it is a ValueNode.
///
/// \param[in] node The node which may hold a value.
///
/// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set, or value type is mismatch.
template <typename T, typename = typename std::enable_if_t<
is_wrapper_ptr<T>::value && std::is_base_of_v<Value, typename T::element_type>, T>>
inline T GetValueNode(const AnfNodePtr &node) {
auto value = GetValueNode(node);
if (value == nullptr) {
return nullptr;
}
return value->cast<T>();
}

/// \brief Check whether the given node is a cnode with the given Primitive as the first input.
///
/// \param[in] node The given node to be checked.
/// \param[in] prim The Primitive value, nullptr means match any Primitive.
///
/// \return True if the node is cnode and the first input is the given Primitive, false otherwise.
MIND_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim = nullptr);

/// \brief Check whether the given node is a ValueNode with the given Primitive.
///
/// \param[in] node The given node to be checked.
/// \param[in] prim The Primitive value.
///
/// \return True if the given node is a ValueNode with the given Primitive, false otherwise.
MIND_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim);

/// \brief Check if a node is a data node.
/// Some nodes may be used internally to pass some non-data states, those nodes are not data nodes.
///
/// \param[in] node The node to be checked.
///
/// \return True if the node is a data node, false otherwise.
MIND_API bool IsDataNode(const AnfNodePtr &node);
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_ANF_H_

+ 50
- 0
mindspore/core/mindapi/ir/common.h View File

@@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_COMMON_H_
#define MINDSPORE_CORE_MINDAPI_IR_COMMON_H_

#include <vector>
#include "mindapi/base/shared_ptr.h"

namespace mindspore::api {
class AnfNode;
using AnfNodePtr = SharedPtr<AnfNode>;
using AnfNodePtrList = std::vector<AnfNodePtr>;

class Value;
using ValuePtr = SharedPtr<Value>;

class Primitive;
using PrimitivePtr = SharedPtr<Primitive>;

class Type;
using TypePtr = SharedPtr<Type>;

class AbstractBase;
using AbstractBasePtr = SharedPtr<AbstractBase>;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;

class Shape;
using ShapePtr = SharedPtr<Shape>;

class FuncGraph;
using FuncGraphPtr = SharedPtr<FuncGraph>;

class FuncGraphManager;
using FuncGraphManagerPtr = SharedPtr<FuncGraphManager>;
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_COMMON_H_

+ 193
- 0
mindspore/core/mindapi/ir/func_graph.h View File

@@ -0,0 +1,193 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
#define MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_

#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/anf.h"
#include "mindapi/ir/primitive.h"
#include "mindapi/ir/value.h"
#include "mindapi/ir/utils.h"

namespace mindspore {
class FuncGraphManager;
}

namespace mindspore::api {
/// \brief FuncGraph defines interface for a function graph.
class MIND_API FuncGraph : public Value {
public:
MIND_API_BASE_MEMBER(FuncGraph);

/// \brief Get the input parameters.
///
/// \return Input parameters of this graph.
std::vector<AnfNodePtr> get_inputs() const;

/// \brief Get all parameters.
///
/// \return All parameters of this graph.
std::vector<AnfNodePtr> parameters() const;

/// \brief Adds a parameter to this graph.
///
/// \param[in] p The parameter to be added.
void add_parameter(const ParameterPtr &p);

/// \brief Adds a new parameter to this graph.
///
/// \return The new added parameter.
ParameterPtr add_parameter();

/// \brief Get the output node.
///
/// \return The output node, nullptr if output not set.
AnfNodePtr output() const;

/// \brief Get the return CNode.
///
/// \return The return CNode, nullptr if no return node.
CNodePtr get_return() const;

/// \brief Set the output node.
///
/// \param[in] value The output node to be set.
/// \param[in] force_new_ret If true, a new return node is always created.
void set_output(const AnfNodePtr &value, bool force_new_ret = false);

/// \brief Set the return node.
///
/// \param[in] cnode The return CNode to be set.
void set_return(const CNodePtr &cnode);

/// \brief Creates a new CNode in this graph.
///
/// \param[in] inputs The input nodes of the new CNode.
///
/// \return The created CNode.
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());

/// \brief Creates a new primitive CNode in this graph.
///
/// \param[in] primitive The primitive of the new CNode.
/// \param[in] prim_inputs The argument inputs of the primitive CNode.
///
/// \return The created primitive CNode.
CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);

/// \brief Get all nodes in this graph.
///
/// \return All nodes in this graph.
std::vector<AnfNodePtr> nodes() const;

/// \brief Check whether an attribute is set for this graph.
///
/// \param[in] key The attribute key (name).
///
/// \return True if the attribute with the given key is set, false otherwise.
bool has_attr(const std::string &key) const;

/// \brief Get an attribute value by its key.
///
/// \param[in] key The attribute key (name).
///
/// \return The attribute value for the given key, nullptr if attribute not found.
ValuePtr get_attr(const std::string &key) const;

/// \brief Set an attribute value.
///
/// \param[in] key The attribute key (name).
/// \param[in] value The attribute value.
void set_attr(const std::string &key, const ValuePtr &value);

/// \brief Get the manager for this graph.
///
/// \return The manager of this graph, nullptr if not set.
FuncGraphManagerPtr manager() const;

/// \brief Creates an empty function graph.
///
/// \return The created function graph.
static FuncGraphPtr Create();

/// \brief Topological sort a graph from the given end node.
///
/// \param[in] node The end node of the graph to be sorted.
///
/// \return The sorted nodes.
static std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &node);
};

/// \brief FuncGraphManager defines interface for function graph management.
class MIND_API FuncGraphManager {
public:
/// \brief Create FuncGraphManager with the given implementor object.
///
/// \param[in] impl The pointer to the implementor object.
explicit FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl);

/// \brief Get the shared_ptr to the underly implementation object.
///
/// \return The shared_ptr to the underly implementation object.
const std::shared_ptr<mindspore::FuncGraphManager> &impl() const { return impl_; }

/// \brief Replace an old node with a new node, related edges are all updated.
///
/// \param[in] old_node The old node to be replaced.
/// \param[in] new_node The new node that replace the old one.
///
/// \return True if the node is successfully replaced, false otherwise.
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);

/// \brief Change an existed edge by replace its input node.
///
/// \param[in] node The output node of the edge.
/// \param[in] index The input index in output node.
/// \param[in] value The new input node of the edge.
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);

/// \brief Adds a new edge between the given two nodes.
///
/// \param[in] node The output node of the edge.
/// \param[in] value The input node of the edge.
void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value);

/// \brief Find users of the given node.
///
/// \param[in] node The node.
///
/// \return Users of the given node, empty if user not found.
std::vector<std::pair<AnfNodePtr, int>> GetUsers(const AnfNodePtr &node) const;

/// \brief Manage the give function graph.
///
/// \param[in] func_graph The function graph to be managed.
/// \param[in] manage If true, the created manager will be set in the graph.
///
/// \return The manager that manages the given function graph.
static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);

private:
const std::shared_ptr<mindspore::FuncGraphManager> impl_;
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_

+ 79
- 0
mindspore/core/mindapi/ir/primitive.h View File

@@ -0,0 +1,79 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_
#define MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_

#include <vector>
#include <string>
#include <unordered_map>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/value.h"

namespace mindspore::api {
/// \brief Primitive defines a primitive operator.
class MIND_API Primitive : public Value {
public:
MIND_API_BASE_MEMBER(Primitive);

/// \brief Create primitive with the given name.
///
/// \param[in] name The primitive name.
explicit Primitive(const std::string &name);

/// \brief Get name of the primitive.
///
/// \return The name of primitive.
const std::string &name() const;

/// \brief Add attribute to primitive.
///
/// \param[in] name The attribute name.
/// \param[in] attr The attribute value.
/// \return The primitive to which attribute has been added.
Primitive &AddAttr(const std::string &name, const ValuePtr &attr);

/// \brief Add attributes by using a map, all elements of the map will be added to this primitive.
///
/// \param[in] attrs The attribute map needs to be added in the primitive attribute.
/// \return The primitive to which attribute has been added.
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs);

/// \brief Erase attribute to the primitive attribute map.
///
/// \param[in] name The attribute name.
void EraseAttr(const std::string &name);

/// \brief Get attribute value by name.
///
/// \param[in] name the attribute name.
/// \return The value of the attribute, null if attribute name not found.
ValuePtr GetAttr(const std::string &name) const;

/// \brief Check If Primitive has an attribute with then given name.
///
/// \param[in] name The attribute name.
/// \return True if there is an attribute with the given name, otherwise false.
bool HasAttr(const std::string &name) const;

/// \brief Get all attributes of this primitive as a map.
///
/// \return The attribute map of this primitive.
std::unordered_map<std::string, ValuePtr> attrs() const;
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_PRIMITIVE_H_

+ 36
- 0
mindspore/core/mindapi/ir/shape.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_
#define MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_

#include "mindapi/base/base.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/ir/common.h"

namespace mindspore::api {
/// \brief Shape defines dimensions of a tensor.
class MIND_API Shape : public Base {
public:
MIND_API_BASE_MEMBER(Shape);

/// \brief Get the shape dimensions.
///
/// \return The shape dimensions.
const ShapeVector &shape() const;
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_SHAPE_H_

+ 93
- 0
mindspore/core/mindapi/ir/tensor.h View File

@@ -0,0 +1,93 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_
#define MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_

#include <cstdint>
#include "mindapi/base/base.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/base/type_id.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/value.h"

namespace mindspore::api {
/// \brief Tensor represents a multi-dimensional array of elements.
class MIND_API Tensor : public Value {
public:
MIND_API_BASE_MEMBER(Tensor);

/// \brief Create a lazy allocated tensor.
///
/// \param[in] data_type [TypeId] Data type of the tensor.
/// \param[in] shape The shape represented by ShapeVector of the tensor.
Tensor(TypeId data_type, const ShapeVector &shape);

/// \brief Create a tensor with input data buffer.
///
/// \param[in] data_type [TypeId] Data type of the tensor.
/// \param[in] shape The shape represented by ShapeVector of the tensor.
/// \param[in] data The input data to be copied into tensor.
/// \param[in] data_len The length of data in bytes.
Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len);

/// \brief Get the shape of the tensor.
/// The shape of a tensor is stored in a vector<int64_t>. Each element of the
/// vector represents the size of a dimension of the tensor. The order of each
/// element in the vector is the same as the the dimension's order it represents.
///
/// \return A vector<int64_t> which represents the shape of the tensor.
const ShapeVector &shape() const;

/// \brief Set the shape of tensor.
///
/// \param[in] shape The shape to be set.
void set_shape(const ShapeVector &shape);

/// \brief Get the data type of the tensor.
///
/// \return The data type of the tensor.
TypeId data_type() const;

/// \brief Set the data type of the tensor.
///
/// \param[in] data_type The data type to be set.
void set_data_type(const TypeId data_type);

/// \brief Get The pointer to the underlying memory block for data storage.
///
/// \return The pointer to the underlying data.
const void *data() const;

/// \brief Get The pointer to the underlying memory block for data storage.
///
/// \return The pointer to the underlying data.
void *data();

/// \brief Get tensor data size.
///
/// \return The total number of elements in the tensor.
int DataSize() const;

/// \brief Get tensor data size in bytes.
///
/// \return The total number of bytes for the tensor data.
std::size_t Size() const;
};

using TensorPtr = SharedPtr<Tensor>;
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_TENSOR_H_

+ 56
- 0
mindspore/core/mindapi/ir/type.h View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_TYPE_H_
#define MINDSPORE_CORE_MINDAPI_IR_TYPE_H_

#include "mindapi/base/base.h"
#include "mindapi/base/type_id.h"
#include "mindapi/ir/common.h"
#include "mindapi/ir/value.h"

namespace mindspore::api {
/// \brief Type defines the type of a value.
class MIND_API Type : public Value {
public:
MIND_API_BASE_MEMBER(Type);

/// \brief Get the id of the Type object.
///
/// \return The id of the Type object.
TypeId type_id() const;

/// \brief Get the number type of the Type object.
///
/// \return The number type of this Type object, kTypeUnknown if this is not a number type.
TypeId number_type() const;

/// \brief Get the Type according to a TypeId.
///
/// \param[in] id The id of the type.
///
/// \return The pointer to the Type.
static TypePtr GetType(TypeId id);

/// \brief Get data size in bytes for the type according to a TypeId.
///
/// \param[in] id The id of the type.
///
/// \return The data size in bytes for the Type.
static size_t GetSize(TypeId id);
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_TYPE_H_

+ 70
- 0
mindspore/core/mindapi/ir/utils.h View File

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_UTILS_H_
#define MINDSPORE_CORE_MINDAPI_IR_UTILS_H_

#include "mindapi/base/base.h"
#include "mindapi/base/shared_ptr.h"
#include "mindapi/base/type_traits.h"
#include "mindapi/ir/anf.h"
#include "mindapi/ir/value.h"
#include "mindapi/ir/func_graph.h"

namespace mindspore::api::utils {
/// \brief Check whether the given object is an instance of the given class.
///
/// \param[in] ptr The pointer to the given object.
///
/// \return True if the pointer is not null and the object is an instance of the given class, false otherwise.
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>, T>>
bool isa(const BasePtr &ptr) {
if (ptr == nullptr) {
return false;
}
return ptr->isa<T>();
}

/// \brief Cast the given object pointer to a pointer with the given class.
///
/// \param[in] ptr The pointer to the object to casted.
///
/// \return A non-null pointer if the input pointer is not null and cast success, nullptr otherwise.
template <typename T, typename = typename std::enable_if_t<is_wrapper_ptr<T>::value, T>>
T cast(const BasePtr &ptr) {
if (ptr == nullptr) {
return nullptr;
}
return ptr->cast<T>();
}

/// \brief Make a copy from the given function graph.
///
/// \param[in] func_graph The graph to be cloned.
///
/// \return The cloned graph.
MIND_API FuncGraphPtr CloneGraph(const FuncGraphPtr &func_graph);

/// \brief Get pad mode id from a value holds the pad mode name or id.
///
/// \param[in] value The value holds the pad mode name or id.
/// \param[in] is_upper Indicates whether the name is uppercase or lowercase, default is false for lowercase.
///
/// \return The pad mode id.
MIND_API int64_t GetPadMode(const ValuePtr &value, bool is_upper = false);
} // namespace mindspore::api::utils

#endif // MINDSPORE_CORE_MINDAPI_IR_UTILS_H_

+ 270
- 0
mindspore/core/mindapi/ir/value.h View File

@@ -0,0 +1,270 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
#define MINDSPORE_CORE_MINDAPI_IR_VALUE_H_

#include <vector>
#include <string>
#include <type_traits>
#include "mindapi/base/base.h"
#include "mindapi/ir/common.h"

namespace mindspore::api {
template <typename T>
struct ImmTrait {};

#define MIND_API_IMM_TRAIT(typeimm, prototype) \
template <> \
struct ImmTrait<prototype> { \
using type = SharedPtr<typeimm>; \
}

/// \brief Value represents a value in expression.
class MIND_API Value : public Base {
public:
MIND_API_BASE_MEMBER(Value);

/// \brief Get the type of this Value.
///
/// \return The type.
TypePtr type() const;

/// \brief Get the abstract of this Value.
///
/// \return Abstract of this Value.
AbstractBasePtr ToAbstract() const;
};

/// \brief ValueSequence represents a sequence of values.
class MIND_API ValueSequence : public Value {
public:
MIND_API_BASE_MEMBER(ValueSequence);

/// \brief Get the size of this ValueSequence.
///
/// \return The size as the number of elements.
std::size_t size() const;

/// \brief Get the list of values in this ValueSequence.
///
/// \return The list of element values.
std::vector<ValuePtr> value() const;
};

using ValueSequencePtr = SharedPtr<ValueSequence>;

/// \brief ValueTuple represents a value tuple.
class MIND_API ValueTuple : public ValueSequence {
public:
MIND_API_BASE_MEMBER(ValueTuple);

/// \brief Constructor of ValueTuple.
///
/// \param[in] elements The elements of the tuple.
explicit ValueTuple(const std::vector<ValuePtr> &elements);
};

using ValueTuplePtr = SharedPtr<ValueTuple>;

/// \brief StringImm defines a Value whose type is string.
class MIND_API StringImm : public Value {
public:
MIND_API_BASE_MEMBER(StringImm);

/// \brief Create StringImm with the given string.
///
/// \param[in] str The given string value.
explicit StringImm(const std::string &str);

/// \brief Get the string value of this StringImm.
///
/// \return The string value of this StringImm.
const std::string &value() const;
};

using StringImmPtr = SharedPtr<StringImm>;

MIND_API_IMM_TRAIT(StringImm, std::string);

/// \beief Scalar defines interface for scalar data.
class MIND_API Scalar : public Value {
public:
MIND_API_BASE_MEMBER(Scalar);
};

/// \beief BoolImm defines interface for bool data.
class MIND_API BoolImm : public Scalar {
public:
MIND_API_BASE_MEMBER(BoolImm);

/// \brief Create BoolImm with the given bool value.
///
/// \param[in] b The given bool value.
explicit BoolImm(bool b);

/// \brief Get the bool value of this BoolImm.
///
/// \return The bool value of this BoolImm.
bool value() const;
};

using BoolImmPtr = SharedPtr<BoolImm>;

MIND_API_IMM_TRAIT(BoolImm, bool);

/// \beief IntegerImm defines interface for integer data.
class MIND_API IntegerImm : public Scalar {
public:
MIND_API_BASE_MEMBER(IntegerImm);
};

/// \beief Int64Imm defines interface for int64 data.
class MIND_API Int64Imm : public IntegerImm {
public:
MIND_API_BASE_MEMBER(Int64Imm);

/// \brief Create Int64Imm with the given int64 value.
///
/// \param[in] value The given bool value.
explicit Int64Imm(int64_t value);

/// \brief Get the int64 value of this Int64Imm.
///
/// \return The int64 value of this Int64Imm.
int64_t value() const;
};

using Int64ImmPtr = SharedPtr<Int64Imm>;

MIND_API_IMM_TRAIT(Int64Imm, int64_t);

/// \beief FloatImm defines interface for float data.
class MIND_API FloatImm : public Scalar {
public:
MIND_API_BASE_MEMBER(FloatImm);
};

/// \beief FP32Imm defines interface for float32 data.
class MIND_API FP32Imm : public FloatImm {
public:
MIND_API_BASE_MEMBER(FP32Imm);

/// \brief Create FP32Imm with the given float value.
///
/// \param[in] value The given float value.
explicit FP32Imm(float value);

/// \brief Get the float value of this FP32Imm.
///
/// \return The float value of this FP32Imm.
float value() const;
};

using FP32ImmPtr = SharedPtr<FP32Imm>;

MIND_API_IMM_TRAIT(FP32Imm, float);

// === Utility functions for Value === //

/// \brief brief Create a Value object from a primitive type value.
///
/// \param[in] v The primitive type value.
///
/// \return The created Value object with the given primitive type value.
template <typename T, typename U = typename ImmTrait<T>::type::element_type>
inline ValuePtr MakeValue(T v) {
return MakeShared<U>(v);
}

/// \brief brief Create a StringImm Value object from a C string.
///
/// \param[in] s The C string.
///
/// \return The created StringImm Value object.
inline ValuePtr MakeValue(const char *s) { return MakeShared<StringImm>(std::string(s)); }

/// \brief brief Create a Int64Imm Value object from a int value.
///
/// \param[in] i The int value.
///
/// \return The created Int64Imm Value object.
inline ValuePtr MakeValue(int i) { return MakeShared<Int64Imm>(static_cast<int64_t>(i)); }

/// \brief brief Create a ValueSequence object from a vector of values.
///
/// \param[in] values The vector of values.
///
/// \return The created ValueSequence object.
inline ValuePtr MakeValue(const std::vector<ValuePtr> &values) { return MakeShared<ValueTuple>(values); }

/// \brief Create a ValueSequence object from a vector of primitive type values.
///
/// \param[in] values The vector of primitive values.
///
/// \return The created ValueSequence object.
template <typename T, typename = typename std::enable_if_t<is_vector<T>::value, T>>
inline ValuePtr MakeValue(const T &values) {
std::vector<ValuePtr> value_vector;
value_vector.reserve(values.size());
for (auto &value : values) {
value_vector.emplace_back(MakeValue(value));
}
return MakeShared<ValueTuple>(value_vector);
}

/// \brief brief Get primitive type value from a Value object.
///
/// \param[in] value The pointer to the Value object.
///
/// \return The primitive type value of the Value object.
template <typename T, typename U = typename ImmTrait<T>::type>
inline T GetValue(const ValuePtr &value) {
if (value == nullptr) {
return T();
}
U imm = value->cast<U>();
if (imm == nullptr) {
return T();
}
return imm->value();
}

/// \brief brief Get primitive element values from a ValueSequeue object.
///
/// \param[in] value The pointer to the ValueSequeue object.
///
/// \return The primitive type values as a vector.
template <typename T, typename S = typename std::decay_t<T>,
typename U = typename std::enable_if_t<is_vector<S>::value, typename S::value_type>>
inline std::vector<U> GetValue(const ValuePtr &value) {
if (value == nullptr) {
return {};
}
auto seq = value->cast<ValueSequencePtr>();
if (seq == nullptr) {
return {};
}
auto elements = seq->value();
std::vector<U> result;
result.reserve(elements.size());
for (auto &e : elements) {
result.emplace_back(GetValue<U>(e));
}
return result;
}
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IR_VALUE_H_

+ 84
- 0
mindspore/core/mindapi/src/abstract.cc View File

@@ -0,0 +1,84 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindapi/ir/abstract.h"
#include "mindapi/src/helper.h"
#include "abstract/abstract_value.h"
#include "ir/dtype.h"
#include "ir/value.h"

namespace mindspore::api {
using TypeImpl = mindspore::Type;
using ValueImpl = mindspore::Value;
using AbstractBaseImpl = mindspore::abstract::AbstractBase;

MIND_API_BASE_IMPL(AbstractBase, AbstractBaseImpl, Base);

AbstractBasePtr AbstractBase::Clone() const {
auto abs = ToRef<AbstractBaseImpl>(impl_).Clone();
return ToWrapper<AbstractBase>(abs);
}

TypePtr AbstractBase::type() const {
auto t = ToRef<AbstractBaseImpl>(impl_).BuildType();
return ToWrapper<Type>(t);
}

ValuePtr AbstractBase::value() const {
auto v = ToRef<AbstractBaseImpl>(impl_).BuildValue();
return ToWrapper<Value>(v);
}

void AbstractBase::set_type(const TypePtr &type) {
auto type_impl = ToImpl<TypeImpl>(type);
ToRef<AbstractBaseImpl>(impl_).set_type(type_impl);
}

void AbstractBase::set_value(const ValuePtr &value) {
auto value_impl = ToImpl<ValueImpl>(value);
ToRef<AbstractBaseImpl>(impl_).set_value(value_impl);
}

using AbstractTensorImpl = mindspore::abstract::AbstractTensor;

MIND_API_BASE_IMPL(AbstractTensor, AbstractTensorImpl, AbstractBase);

AbstractTensor::AbstractTensor(TypeId type, const ShapeVector &shape)
: AbstractBase(std::make_shared<AbstractTensorImpl>(mindspore::TypeIdToType(type), shape)) {}

AbstractBasePtr AbstractTensor::element() const {
auto abs = ToRef<AbstractTensorImpl>(impl_).element();
return ToWrapper<AbstractBase>(abs);
}

ShapePtr AbstractTensor::shape() const {
auto s = ToRef<AbstractTensorImpl>(impl_).shape();
return ToWrapper<Shape>(s);
}

using AbstractSequenceImpl = mindspore::abstract::AbstractSequeue;

MIND_API_BASE_IMPL(AbstractSequence, AbstractSequenceImpl, AbstractBase);

AbstractBasePtrList AbstractSequence::elements() const {
auto &impl_elements = ToRef<AbstractSequenceImpl>(impl_).elements();
return ToWrapperVector<AbstractBase>(impl_elements);
}

using AbstractTupleImpl = mindspore::abstract::AbstractTuple;

MIND_API_BASE_IMPL(AbstractTuple, AbstractTupleImpl, AbstractSequence);
} // namespace mindspore::api

+ 135
- 0
mindspore/core/mindapi/src/anf.cc View File

@@ -0,0 +1,135 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindapi/ir/anf.h"
#include "mindapi/src/helper.h"
#include "ir/anf.h"
#include "ir/value.h"
#include "ir/primitive.h"
#include "abstract/abstract_value.h"

namespace mindspore::api {
using ValueImpl = mindspore::Value;
using AnfNodeImpl = mindspore::AnfNode;
using PrimitiveImpl = mindspore::Primitive;
using AbstractBaseImpl = mindspore::abstract::AbstractBase;

MIND_API_BASE_IMPL(AnfNode, AnfNodeImpl, Base);

std::string AnfNode::fullname_with_scope() const { return ToRef<AnfNodeImpl>(impl_).fullname_with_scope(); }

AbstractBasePtr AnfNode::abstract() const {
const auto &abs = ToRef<AnfNodeImpl>(impl_).abstract();
return ToWrapper<AbstractBase>(abs);
}

void AnfNode::set_abstract(const AbstractBasePtr &abs) {
ToRef<AnfNodeImpl>(impl_).set_abstract(ToImpl<AbstractBaseImpl>(abs));
}

using CNodeImpl = mindspore::CNode;

MIND_API_BASE_IMPL(CNode, CNodeImpl, AnfNode);

size_t CNode::size() const { return ToRef<CNodeImpl>(impl_).size(); }

AnfNodePtr CNode::input(size_t i) const {
auto &input = ToRef<CNodeImpl>(impl_).input(i);
return ToWrapper<AnfNode>(input);
}

std::vector<AnfNodePtr> CNode::inputs() const {
auto &impl_inputs = ToRef<CNodeImpl>(impl_).inputs();
return ToWrapperVector<AnfNode>(impl_inputs);
}

void CNode::set_inputs(const std::vector<AnfNodePtr> &inputs) {
auto impl_inputs = ToImplVector<AnfNodeImpl>(inputs);
ToRef<CNodeImpl>(impl_).set_inputs(impl_inputs);
}

void CNode::add_input(const AnfNodePtr &input) {
auto impl_input = ToImpl<AnfNodeImpl>(input);
MS_EXCEPTION_IF_NULL(impl_input);
ToRef<CNodeImpl>(impl_).add_input(impl_input);
}

void CNode::set_fullname_with_scope(const std::string &full_name) {
ToRef<CNodeImpl>(impl_).set_fullname_with_scope(full_name);
}

void CNode::AddAttr(const std::string &name, const ValuePtr &attr) {
auto impl_attr = ToImpl<ValueImpl>(attr);
MS_EXCEPTION_IF_NULL(impl_attr);
ToRef<CNodeImpl>(impl_).AddAttr(name, impl_attr);
}

void CNode::EraseAttr(const std::string &name) { ToRef<CNodeImpl>(impl_).EraseAttr(name); }

ValuePtr CNode::GetAttr(const std::string &name) const {
auto v = ToRef<CNodeImpl>(impl_).GetAttr(name);
return ToWrapper<Value>(v);
}

using ParameterImpl = mindspore::Parameter;

MIND_API_BASE_IMPL(Parameter, ParameterImpl, AnfNode);

std::string Parameter::name() const { return ToRef<ParameterImpl>(impl_).name(); }

void Parameter::set_name(const std::string &name) { ToRef<ParameterImpl>(impl_).set_name(name); }

bool Parameter::has_default() const { return ToRef<ParameterImpl>(impl_).has_default(); }

void Parameter::set_default_param(const ValuePtr &param) {
auto v = ToImpl<ValueImpl>(param);
ToRef<ParameterImpl>(impl_).set_default_param(v);
}

ValuePtr Parameter::default_param() const {
auto v = ToRef<ParameterImpl>(impl_).default_param();
return ToWrapper<Value>(v);
}

using ValueNodeImpl = mindspore::ValueNode;

MIND_API_BASE_IMPL(ValueNode, ValueNodeImpl, AnfNode);

ValueNode::ValueNode(const ValuePtr &value) : AnfNode(std::make_shared<ValueNodeImpl>(ToImpl<ValueImpl>(value))) {}

ValuePtr ValueNode::value() const {
auto v = ToRef<ValueNodeImpl>(impl_).value();
return ToWrapper<Value>(v);
}

bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim) {
auto node_impl = ToImpl<AnfNodeImpl>(node);
auto prim_impl = ToImpl<PrimitiveImpl>(prim);
return mindspore::IsPrimitiveCNode(node_impl, prim_impl);
}

bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim) {
auto node_impl = ToImpl<AnfNodeImpl>(node);
auto prim_impl = ToImpl<PrimitiveImpl>(prim);
return mindspore::IsPrimitive(node_impl, prim_impl);
}

bool IsDataNode(const AnfNodePtr &node) {
auto node_impl = ToImpl<AnfNodeImpl>(node);
// We assume that node with monad abstract is not a data node.
return !HasAbstractMonad(node_impl);
}
} // namespace mindspore::api

+ 28
- 0
mindspore/core/mindapi/src/base.cc View File

@@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindapi/base/base.h"
#include "base/base.h"

namespace mindspore::api {
Base::Base(const std::shared_ptr<mindspore::Base> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl_); }

uint32_t Base::ClassId() { return mindspore::Base::kTypeId; }

bool Base::IsFromClassId(uint32_t class_id) const { return impl_->IsFromTypeId(class_id); }

std::string Base::ToString() const { return impl_->ToString(); }
} // namespace mindspore::api

+ 170
- 0
mindspore/core/mindapi/src/func_graph.cc View File

@@ -0,0 +1,170 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>
#include "mindapi/ir/func_graph.h"
#include "mindapi/src/helper.h"
#define USE_DEPRECATED_API
#include "ir/anf.h"
#include "ir/value.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "ir/primitive.h"
#include "ir/graph_utils.h"

namespace mindspore::api {
using ValueImpl = mindspore::Value;
using AnfNodeImpl = mindspore::AnfNode;
using CNodeImpl = mindspore::CNode;
using PrimitiveImpl = mindspore::Primitive;
using ParameterImpl = mindspore::Parameter;
using FuncGraphImpl = mindspore::FuncGraph;
using FuncGraphManagerImpl = mindspore::FuncGraphManager;

MIND_API_BASE_IMPL(FuncGraph, FuncGraphImpl, Value);

std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
auto &inputs = ToRef<FuncGraphImpl>(impl_).get_inputs();
return ToWrapperVector<AnfNode>(inputs);
}

std::vector<AnfNodePtr> FuncGraph::parameters() const {
auto &params = ToRef<FuncGraphImpl>(impl_).parameters();
return ToWrapperVector<AnfNode>(params);
}

void FuncGraph::add_parameter(const ParameterPtr &p) {
auto param_impl = ToImpl<ParameterImpl>(p);
ToRef<FuncGraphImpl>(impl_).add_parameter(param_impl);
}

ParameterPtr FuncGraph::add_parameter() {
auto param_impl = ToRef<FuncGraphImpl>(impl_).add_parameter();
return ToWrapper<Parameter>(param_impl);
}

AnfNodePtr FuncGraph::output() const {
auto output = ToRef<FuncGraphImpl>(impl_).output();
return ToWrapper<AnfNode>(output);
}

CNodePtr FuncGraph::get_return() const {
auto ret = ToRef<FuncGraphImpl>(impl_).get_return();
return ToWrapper<CNode>(ret);
}

void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
auto output = ToImpl<AnfNodeImpl>(value);
ToRef<FuncGraphImpl>(impl_).set_output(output);
}

void FuncGraph::set_return(const CNodePtr &cnode) {
auto cnode_impl = ToImpl<CNodeImpl>(cnode);
ToRef<FuncGraphImpl>(impl_).set_return(cnode_impl);
}

CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
auto inputs_impl = ToImplVector<AnfNodeImpl>(inputs);
auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(std::move(inputs_impl));
return ToWrapper<CNode>(cnode_impl);
}

CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) {
auto prim_impl = ToImpl<PrimitiveImpl>(primitive);
auto prim_inputs_impl = ToImplVector<AnfNodeImpl>(prim_inputs);
auto cnode_impl = ToRef<FuncGraphImpl>(impl_).NewCNode(prim_impl, prim_inputs_impl);
return ToWrapper<CNode>(cnode_impl);
}

std::vector<AnfNodePtr> FuncGraph::nodes() const {
auto &nodes = ToRef<FuncGraphImpl>(impl_).nodes();
return ToWrapperVector<AnfNode>(nodes);
}

bool FuncGraph::has_attr(const std::string &key) const { return ToRef<FuncGraphImpl>(impl_).has_attr(key); }

ValuePtr FuncGraph::get_attr(const std::string &key) const {
auto v = ToRef<FuncGraphImpl>(impl_).get_attr(key);
return ToWrapper<Value>(v);
}

void FuncGraph::set_attr(const std::string &key, const ValuePtr &value) {
auto value_impl = ToImpl<ValueImpl>(value);
ToRef<FuncGraphImpl>(impl_).set_attr(key, value_impl);
}

FuncGraphManagerPtr FuncGraph::manager() const {
auto manager = ToRef<FuncGraphImpl>(impl_).manager();
if (manager == nullptr) {
return nullptr;
}
return MakeShared<FuncGraphManager>(manager);
}

FuncGraphPtr FuncGraph::Create() {
auto fg = std::make_shared<FuncGraphImpl>();
return ToWrapper<FuncGraph>(fg);
}

std::vector<AnfNodePtr> FuncGraph::TopoSort(const AnfNodePtr &node) {
auto node_impl = ToImpl<AnfNodeImpl>(node);
if (node_impl == nullptr) {
return {};
}
auto sorted = mindspore::TopoSort(node_impl);
return ToWrapperVector<AnfNode>(sorted);
}

// FuncGraphManager is not derived from Base, we implement it directly.
FuncGraphManager::FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl) : impl_(impl) {
MS_EXCEPTION_IF_NULL(impl_);
}

bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
return impl_->Replace(ToImpl<AnfNodeImpl>(old_node), ToImpl<AnfNodeImpl>(new_node));
}

void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
return impl_->SetEdge(ToImpl<AnfNodeImpl>(node), index, ToImpl<AnfNodeImpl>(value));
}

void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) {
return impl_->AddEdge(ToImpl<AnfNodeImpl>(node), ToImpl<AnfNodeImpl>(value));
}

std::vector<std::pair<AnfNodePtr, int>> FuncGraphManager::GetUsers(const AnfNodePtr &node) const {
auto &node_users = impl_->node_users();
auto iter = node_users.find(ToImpl<AnfNodeImpl>(node));
if (iter == node_users.end()) {
return {};
}
auto &users_impl = iter->second;
std::vector<std::pair<AnfNodePtr, int>> users;
users.reserve(users_impl.size());
std::transform(users_impl.begin(), users_impl.end(), std::back_inserter(users),
[](const auto &user) { return std::make_pair(ToWrapper<AnfNode>(user.first), user.second); });
return users;
}

FuncGraphManagerPtr FuncGraphManager::Manage(const FuncGraphPtr &func_graph, bool manage) {
auto fg_impl = ToImpl<FuncGraphImpl>(func_graph);
auto mgr_impl = mindspore::Manage(fg_impl, manage);
if (mgr_impl == nullptr) {
return nullptr;
}
return MakeShared<FuncGraphManager>(mgr_impl);
}
} // namespace mindspore::api

+ 76
- 0
mindspore/core/mindapi/src/helper.h View File

@@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_
#define MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_

#include <memory>
#include <vector>
#include <type_traits>
#include "mindapi/base/base.h"

namespace mindspore::api {
template <typename T, typename U>
T &ToRef(const std::shared_ptr<U> &ptr) {
return static_cast<T &>(*ptr);
}

template <typename T, typename U, typename = typename std::enable_if_t<std::is_base_of_v<mindspore::Base, T>>,
typename = typename std::enable_if_t<std::is_base_of_v<Base, U>>>
std::shared_ptr<T> ToImpl(const SharedPtr<U> &wrapper) {
if (wrapper == nullptr || wrapper->impl() == nullptr) {
return nullptr;
}
return std::dynamic_pointer_cast<T>(wrapper->impl());
}

template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Base, T>>>
SharedPtr<T> ToWrapper(const std::shared_ptr<mindspore::Base> &impl) {
if (impl == nullptr) {
return nullptr;
}
return MakeShared<T>(impl);
}

template <typename T, typename U>
std::vector<std::shared_ptr<T>> ToImplVector(const U &wrapper_vector) {
std::vector<std::shared_ptr<T>> impl_vector;
impl_vector.reserve(wrapper_vector.size());
for (auto &wrapper : wrapper_vector) {
impl_vector.emplace_back(ToImpl<T>(wrapper));
}
return impl_vector;
}

template <typename T, typename U>
std::vector<SharedPtr<T>> ToWrapperVector(const U &impl_vector) {
std::vector<SharedPtr<T>> wrapper_vector;
wrapper_vector.reserve(impl_vector.size());
for (auto &impl : impl_vector) {
wrapper_vector.emplace_back(ToWrapper<T>(impl));
}
return wrapper_vector;
}

#define MIND_API_BASE_IMPL(current_class, impl_class, base_class) \
current_class::current_class(const std::shared_ptr<mindspore::Base> &impl) : base_class(impl) { \
if (!impl_->isa<impl_class>()) { \
MS_LOG(EXCEPTION) << "Wrong impl " << impl_->type_name() << " for " << #current_class; \
} \
} \
uint32_t current_class::ClassId() { return impl_class::kTypeId; }
} // namespace mindspore::api
#endif // MINDSPORE_CORE_MINDAPI_IMPL_HELPER_H_

+ 75
- 0
mindspore/core/mindapi/src/logging.cc View File

@@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#define MIND_LOG_NO_MS_LOG
#include "mindapi/base/logging.h"
#include "utils/log_adapter.h"

namespace mindspore::api {
static MsLogLevel ToMsLogLevel(LogLevel level) {
switch (level) {
case LogLevel::DEBUG:
return MsLogLevel::DEBUG;
case LogLevel::INFO:
return MsLogLevel::INFO;
case LogLevel::WARNING:
return MsLogLevel::WARNING;
case LogLevel::ERROR:
return MsLogLevel::ERROR;
case LogLevel::EXCEPTION:
return MsLogLevel::EXCEPTION;
default:
return MsLogLevel::EXCEPTION;
}
}

class LogWriterImpl {
public:
LogWriterImpl(LogLevel level, const char *file, int line, const char *func)
: writer_(LocationInfo(file, line, func), ToMsLogLevel(level), SubModuleId::SM_API) {}

~LogWriterImpl() = default;

void Write(const LogStream &stream) const noexcept {
mindspore::LogStream log_stream;
log_stream << stream.stream_.rdbuf();
writer_ < log_stream;
}

void WriteAndThrow(const LogStream &stream) const __attribute__((noreturn)) {
mindspore::LogStream log_stream;
log_stream << stream.stream_.rdbuf();
writer_ ^ log_stream;
}

private:
mindspore::LogWriter writer_;
};

LogWriter::LogWriter(LogLevel level, const char *file, int line, const char *func)
: impl_(std::make_unique<LogWriterImpl>(level, file, line, func)) {}

LogWriter::~LogWriter() = default;

void LogWriter::operator<(const LogStream &stream) const noexcept { impl_->Write(stream); }

void LogWriter::operator^(const LogStream &stream) const { impl_->WriteAndThrow(stream); }

bool LogWriter::IsEnabled(LogLevel level) {
auto log_level = ToMsLogLevel(level);
return IS_OUTPUT_ON(log_level);
}
} // namespace mindspore::api

+ 65
- 0
mindspore/core/mindapi/src/primitive.cc View File

@@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindapi/ir/primitive.h"
#include "mindapi/src/helper.h"
#include "ir/primitive.h"
#include "ir/value.h"

namespace mindspore::api {
using ValueImpl = mindspore::Value;
using PrimitiveImpl = mindspore::Primitive;

MIND_API_BASE_IMPL(Primitive, PrimitiveImpl, Value);

Primitive::Primitive(const std::string &name) : Value(std::make_shared<PrimitiveImpl>(name)) {}

const std::string &Primitive::name() const { return ToRef<PrimitiveImpl>(impl_).name(); }

Primitive &Primitive::AddAttr(const std::string &name, const ValuePtr &attr) {
auto value = ToImpl<ValueImpl>(attr);
ToRef<PrimitiveImpl>(impl_).set_attr(name, value);
return *this;
}

Primitive &Primitive::SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
auto value = ToImpl<ValueImpl>(attr.second);
ToRef<PrimitiveImpl>(impl_).set_attr(attr.first, value);
}
return *this;
}

void Primitive::EraseAttr(const std::string &name) { ToRef<PrimitiveImpl>(impl_).EraseAttr(name); }

ValuePtr Primitive::GetAttr(const std::string &name) const {
auto v = ToRef<PrimitiveImpl>(impl_).GetAttr(name);
return ToWrapper<Value>(v);
}

bool Primitive::HasAttr(const std::string &name) const { return ToRef<PrimitiveImpl>(impl_).HasAttr(name); }

std::unordered_map<std::string, ValuePtr> Primitive::attrs() const {
std::unordered_map<std::string, ValuePtr> attr_map;
auto &impl_attrs = ToRef<PrimitiveImpl>(impl_).attrs();
attr_map.reserve(impl_attrs.size());
for (auto &attr : impl_attrs) {
auto value = ToWrapper<Value>(attr.second);
attr_map.emplace(attr.first, value);
}
return attr_map;
}
} // namespace mindspore::api

+ 27
- 0
mindspore/core/mindapi/src/shape.cc View File

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindapi/ir/shape.h"
#include "mindapi/src/helper.h"
#include "abstract/dshape.h"

namespace mindspore::api {
using ShapeImpl = mindspore::abstract::Shape;

MIND_API_BASE_IMPL(Shape, ShapeImpl, Base);

const ShapeVector &Shape::shape() const { return ToRef<ShapeImpl>(impl_).shape(); }
} // namespace mindspore::api

+ 47
- 0
mindspore/core/mindapi/src/tensor.cc View File

@@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include "mindapi/ir/tensor.h"
#include "mindapi/src/helper.h"
#include "ir/tensor.h"

namespace mindspore::api {
using TensorImpl = mindspore::tensor::Tensor;

MIND_API_BASE_IMPL(Tensor, TensorImpl, Value);

Tensor::Tensor(TypeId data_type, const ShapeVector &shape) : Value(std::make_shared<TensorImpl>(data_type, shape)) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
: Value(std::make_shared<TensorImpl>(data_type, shape, data, data_len)) {}

const ShapeVector &Tensor::shape() const { return ToRef<TensorImpl>(impl_).shape(); }

void Tensor::set_shape(const ShapeVector &shape) { (void)ToRef<TensorImpl>(impl_).set_shape(shape); }

TypeId Tensor::data_type() const { return ToRef<TensorImpl>(impl_).data_type(); }

void Tensor::set_data_type(const TypeId data_type) { (void)ToRef<TensorImpl>(impl_).set_data_type(data_type); }

const void *Tensor::data() const { return ToRef<TensorImpl>(impl_).data_c(); }

void *Tensor::data() { return ToRef<TensorImpl>(impl_).data_c(); }

int Tensor::DataSize() const { return ToRef<TensorImpl>(impl_).DataSize(); }

size_t Tensor::Size() const { return ToRef<TensorImpl>(impl_).Size(); }
} // namespace mindspore::api

+ 39
- 0
mindspore/core/mindapi/src/type.cc View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindapi/ir/type.h"
#include "mindapi/ir/value.h"
#include "mindapi/src/helper.h"
#include "ir/dtype/type.h"
#include "ir/dtype.h"
#include "abstract/utils.h"

namespace mindspore::api {
using TypeImpl = mindspore::Type;

MIND_API_BASE_IMPL(Type, TypeImpl, Value);

TypeId Type::type_id() const { return ToRef<TypeImpl>(impl_).type_id(); }

TypeId Type::number_type() const { return ToRef<TypeImpl>(impl_).number_type(); }

TypePtr Type::GetType(TypeId id) {
auto type_impl = mindspore::TypeIdToType(id);
return ToWrapper<Type>(type_impl);
}

size_t Type::GetSize(TypeId id) { return mindspore::abstract::TypeIdSize(id); }
} // namespace mindspore::api

+ 43
- 0
mindspore/core/mindapi/src/utils.cc View File

@@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindapi/ir/utils.h"
#include "mindapi/src/helper.h"
#define USE_DEPRECATED_API
#include "ir/anf.h"
#include "ir/value.h"
#include "ir/func_graph_cloner.h"
#include "utils/check_convert_utils.h"

namespace mindspore::api::utils {
using ValueImpl = mindspore::Value;
using FuncGraphImpl = mindspore::FuncGraph;

MIND_API FuncGraphPtr CloneGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto fg_impl = ToImpl<FuncGraphImpl>(func_graph);
Cloner cloner({fg_impl}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
auto cloned_fg = cloner[fg_impl];
return ToWrapper<api::FuncGraph>(cloned_fg);
}

int64_t GetPadMode(const api::ValuePtr &value, bool is_upper) {
int64_t result;
auto value_impl = ToImpl<ValueImpl>(value);
CheckAndConvertUtils::GetPadModEnumValue(value_impl, &result, is_upper);
return result;
}
} // namespace mindspore::api::utils

+ 94
- 0
mindspore/core/mindapi/src/value.cc View File

@@ -0,0 +1,94 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindapi/ir/value.h"
#include "mindapi/ir/type.h"
#include "mindapi/ir/abstract.h"
#include "mindapi/src/helper.h"
#include "abstract/abstract_value.h"
#include "ir/anf.h"
#include "ir/dtype/type.h"
#include "ir/value.h"
#include "ir/scalar.h"

namespace mindspore::api {
using ValueImpl = mindspore::Value;
using ValueSequenceImpl = mindspore::ValueSequeue; // 'Sequeue' is typo.
using ValueTupleImpl = mindspore::ValueTuple;
using StringImmImpl = mindspore::StringImm;
using ScalarImpl = mindspore::Scalar;
using BoolImmImpl = mindspore::BoolImm;
using IntegerImmImpl = mindspore::IntergerImm; // 'Interger' is typo.
using Int64ImmImpl = mindspore::Int64Imm;
using FloatImmImpl = mindspore::FloatImm;
using FP32ImmImpl = mindspore::FP32Imm;

MIND_API_BASE_IMPL(Value, ValueImpl, Base);

TypePtr Value::type() const {
auto t = ToRef<ValueImpl>(impl_).type();
return ToWrapper<Type>(t);
}

AbstractBasePtr Value::ToAbstract() const {
auto abs = ToRef<ValueImpl>(impl_).ToAbstract();
return ToWrapper<AbstractBase>(abs);
}

MIND_API_BASE_IMPL(ValueSequence, ValueSequenceImpl, Value);

std::size_t ValueSequence::size() const { return ToRef<ValueSequenceImpl>(impl_).size(); }

std::vector<ValuePtr> ValueSequence::value() const {
auto &elements = ToRef<ValueSequenceImpl>(impl_).value();
return ToWrapperVector<Value>(elements);
}

MIND_API_BASE_IMPL(ValueTuple, ValueTupleImpl, ValueSequence);

ValueTuple::ValueTuple(const std::vector<ValuePtr> &elements)
: ValueSequence(std::make_shared<ValueTupleImpl>(ToImplVector<ValueImpl>(elements))) {}

MIND_API_BASE_IMPL(StringImm, StringImmImpl, Value);

StringImm::StringImm(const std::string &str) : Value(std::make_shared<StringImmImpl>(str)) {}

const std::string &StringImm::value() const { return ToRef<StringImmImpl>(impl_).value(); }

MIND_API_BASE_IMPL(Scalar, ScalarImpl, Value);

MIND_API_BASE_IMPL(BoolImm, BoolImmImpl, Scalar);

BoolImm::BoolImm(bool b) : Scalar(std::make_shared<BoolImmImpl>(b)) {}

bool BoolImm::value() const { return ToRef<BoolImmImpl>(impl_).value(); }

MIND_API_BASE_IMPL(IntegerImm, IntegerImmImpl, Scalar);

MIND_API_BASE_IMPL(Int64Imm, Int64ImmImpl, IntegerImm);

Int64Imm::Int64Imm(int64_t value) : IntegerImm(std::make_shared<Int64ImmImpl>(value)) {}

int64_t Int64Imm::value() const { return ToRef<Int64ImmImpl>(impl_).value(); }

MIND_API_BASE_IMPL(FloatImm, FloatImmImpl, Scalar);

MIND_API_BASE_IMPL(FP32Imm, FP32ImmImpl, FloatImm);

FP32Imm::FP32Imm(float value) : FloatImm(std::make_shared<FP32ImmImpl>(value)) {}

float FP32Imm::value() const { return ToRef<FP32ImmImpl>(impl_).value(); }
} // namespace mindspore::api

+ 1
- 0
mindspore/core/utils/log_adapter.h View File

@@ -141,6 +141,7 @@ enum SubModuleId : int {
SM_HCCL_ADPT, // Hccl Adapter
SM_RUNTIME_FRAMEWORK, // Runtime framework
SM_GE, // GraphEngine
SM_API, // MindAPI
NUM_SUBMODUES // number of submodules
};



+ 2
- 3
mindspore/core/utils/shape_utils.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
#ifndef MINDSPORE_SHAPE_UTILS_INFO_H_
#define MINDSPORE_SHAPE_UTILS_INFO_H_

#include <vector>
using ShapeVector = std::vector<int64_t>;
#include "mindapi/base/shape_vector.h"

#endif // MINDSPORE_SHAPE_UTILS_INFO_H_

+ 2
- 0
mindspore/lite/cmake/file_list.cmake View File

@@ -16,6 +16,8 @@ set(API_IR_HEADER
${CORE_DIR}/api/ir/func_graph.h
${CORE_DIR}/api/ir/func_graph_manager.h
)
file(GLOB MINDAPI_BASE_HEADER ${CORE_DIR}/mindapi/base/*.h)
file(GLOB MINDAPI_IR_HEADER ${CORE_DIR}/mindapi/ir/*.h)
set(BASE_HEADER
${CORE_DIR}/base/base.h
${CORE_DIR}/base/base_ref.h


+ 1
- 0
mindspore/lite/examples/train_lenet/Makefile View File

@@ -15,6 +15,7 @@ OBJ:=$(SRC:.cc=.o)
CFLAGS := -Ofast -std=c++17 \
-I . \
-I ./msl/runtime \
-I ./msl/runtime/include \
-I ./msl/runtime/minddata \
-I ./msl/tools/third_party/flatbuffers/include



+ 1
- 0
mindspore/lite/examples/transfer_learning/Makefile View File

@@ -15,6 +15,7 @@ OBJ:=$(SRC:.cc=.o)
CFLAGS := -Ofast -std=c++17 \
-I . \
-I ./msl/runtime \
-I ./msl/runtime/include \
-I ./msl/runtime/minddata \
-I ./msl/tools/third_party/flatbuffers/include



+ 1
- 0
mindspore/lite/examples/unified_api/Makefile View File

@@ -19,6 +19,7 @@ INF_OBJ:=$(INF_SRC:.cc=.o)
CFLAGS := -Ofast -std=c++17 \
-I . \
-I ./msl/runtime \
-I ./msl/runtime/include \
-I ./msl/runtime/minddata \
-I ./msl/tools/third_party/flatbuffers/include



+ 6
- 0
mindspore/lite/src/CMakeLists.txt View File

@@ -289,6 +289,9 @@ if(APPLE)
set(MINDSPORE_LITE_PUB_HDRS_IR_HDRS
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/type_id.h
)
set(MINDSPORE_LITE_PUB_HDRS_MINDAPI_HDRS
${CMAKE_CURRENT_SOURCE_DIR}/../../core/mindapi/base/type_id.h
)
add_library(mindspore-lite_static STATIC
${LITE_SRC}
${MINDSPORE_LITE_PUB_HDRS}
@@ -423,6 +426,9 @@ if(DEFINED ARCHS)
FOREACH(HDR ${MINDSPORE_LITE_PUB_HDRS_IR_HDRS})
SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/include/ir/dtype/)
ENDFOREACH()
FOREACH(HDR ${MINDSPORE_LITE_PUB_HDRS_MINDAPI_HDRS})
SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/include/mindapi/base/)
ENDFOREACH()
target_link_libraries(mindspore-lite_static)
endif()



+ 1
- 0
tests/ut/cpp/CMakeLists.txt View File

@@ -71,6 +71,7 @@ if(ENABLE_MINDDATA)
./fl/*.cc
./cxx_api/*.cc
./tbe/*.cc
./mindapi/*.cc
)
if(NOT ENABLE_SECURITY)
file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}


+ 394
- 0
tests/ut/cpp/mindapi/mindapi_test.cc View File

@@ -0,0 +1,394 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <memory>
#include <sstream>
#include <unordered_map>
#include "common/common_test.h"
#include "mindapi/base/logging.h"
#include "mindapi/ir/func_graph.h"
#include "mindapi/ir/tensor.h"
#include "mindapi/ir/utils.h"

namespace mindspore::api {
class TestMindApi : public UT::Common {
public:
TestMindApi() = default;
};

/// Feature: MindAPI
/// Description: test basic 'is()' 'cast()'
/// Expectation: is/cast works correctly.
TEST_F(TestMindApi, test_base_isa_cast) {
auto value_node = MakeShared<ValueNode>(MakeValue(0));
auto base = MakeShared<Base>(value_node->impl());
ASSERT_TRUE(base->isa<Base>());
ASSERT_TRUE(base->isa<AnfNode>());
ASSERT_TRUE(base->isa<ValueNode>());
ASSERT_FALSE(base->isa<AbstractBase>());
auto anf_node = base->cast<AnfNodePtr>();
ASSERT_TRUE(anf_node != nullptr);
ASSERT_TRUE(anf_node->impl() == value_node->impl());
ASSERT_TRUE(base->cast<AbstractBasePtr>() == nullptr);
}

/// Feature: MindAPI
/// Description: test graph construction.
/// Expectation: graph is constructed as expected.
TEST_F(TestMindApi, test_graph_construction) {
// fg(x) { return myprim(x, 1); }
auto fg = FuncGraph::Create();
auto x = fg->add_parameter();
x->set_name("x");
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto value_node = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({prim_node, x, value_node});
fg->set_output(cnode);

// Now we check the graph.
ASSERT_EQ(fg->parameters().size(), 1);
ASSERT_TRUE(fg->parameters()[0]->isa<Parameter>());
ASSERT_EQ(fg->parameters()[0]->cast<ParameterPtr>()->name(), "x");

auto ret_node = fg->get_return();
ASSERT_TRUE(ret_node != nullptr);
auto output_node = fg->output();
ASSERT_TRUE(output_node != nullptr);
ASSERT_TRUE(output_node->isa<CNode>());

auto output_cnode = output_node->cast<CNodePtr>();
ASSERT_EQ(output_cnode->inputs().size(), 3);
ASSERT_TRUE(output_cnode->input(0)->isa<ValueNode>());
ASSERT_TRUE(output_cnode->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>());
ASSERT_EQ(output_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()->name(), "myprim");
ASSERT_TRUE(output_cnode->input(1)->isa<Parameter>());
ASSERT_EQ(output_cnode->input(1)->cast<ParameterPtr>()->name(), "x");
ASSERT_TRUE(output_cnode->input(2)->isa<ValueNode>());

ASSERT_EQ(output_cnode->impl(), cnode->impl());
}

/// Feature: MindAPI
/// Description: test value related functions.
/// Expectation: value related functions work as expected.
TEST_F(TestMindApi, test_values) {
int64_t one = 1;
auto s = MakeValue("hello");
auto i = MakeValue(one);
auto i2 = MakeValue(2);
auto b = MakeValue(true);
auto f = MakeValue(3.14f);
auto seq = MakeValue(std::vector<int64_t>{3, 4, 5});
auto seq_str = MakeValue(std::vector<std::string>({"this", "is", "mindspore", "api"}));

ASSERT_TRUE(s->isa<StringImm>());
ASSERT_TRUE(i->isa<Int64Imm>());
ASSERT_TRUE(i2->isa<Int64Imm>());
ASSERT_TRUE(b->isa<BoolImm>());
ASSERT_TRUE(f->isa<FP32Imm>());
ASSERT_TRUE(seq->isa<ValueSequence>());
ASSERT_TRUE(seq_str->isa<ValueSequence>());

ASSERT_EQ(GetValue<std::string>(s), "hello");
ASSERT_EQ(GetValue<int64_t>(i), one);
ASSERT_EQ(GetValue<int64_t>(i2), 2);
ASSERT_TRUE(GetValue<bool>(b));
ASSERT_TRUE(std::abs(GetValue<float>(f) - 3.14f) < 0.00001f);

ASSERT_EQ(GetValue<std::string>(i), "");
ASSERT_EQ(GetValue<int64_t>(s), 0);
ASSERT_FALSE(GetValue<bool>(s));
ASSERT_EQ(GetValue<float>(s), 0.0f);

auto seq_ptr = seq->cast<ValueSequencePtr>();
ASSERT_TRUE(seq_ptr != nullptr);
ASSERT_EQ(seq_ptr->size(), 3);
ASSERT_EQ(seq_ptr->value().size(), 3);
ASSERT_TRUE(seq_ptr->value()[0]->isa<Int64Imm>());
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[0]), 3);
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[1]), 4);
ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[2]), 5);

auto seq_values = GetValue<std::vector<int64_t>>(seq);
ASSERT_EQ(seq_values.size(), 3);
ASSERT_EQ(seq_values[0], 3);
ASSERT_EQ(seq_values[1], 4);
ASSERT_EQ(seq_values[2], 5);

auto str_values = GetValue<std::vector<std::string>>(seq_str);
ASSERT_EQ(str_values.size(), 4);
ASSERT_EQ(str_values[0], "this");
ASSERT_EQ(str_values[1], "is");
ASSERT_EQ(str_values[2], "mindspore");
ASSERT_EQ(str_values[3], "api");
}

/// Feature: MindAPI
/// Description: test graph manager functions.
/// Expectation: graph manager functions work as expected.
TEST_F(TestMindApi, test_func_graph_manager) {
// fg(x, y) { return myprim(add(x, y), 1); }
auto fg = FuncGraph::Create();
auto x = fg->add_parameter();
x->set_name("x");
auto y = fg->add_parameter();
y->set_name("y");
auto add = MakeShared<Primitive>("add");
auto add_node = MakeShared<ValueNode>(add);
auto add_cnode = fg->NewCNode({add_node, x, y});
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto value_node = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
fg->set_output(cnode);

auto mgr = FuncGraphManager::Manage(fg);
ASSERT_TRUE(mgr != nullptr);
ASSERT_TRUE(fg->manager() != nullptr);
ASSERT_EQ(fg->manager()->impl(), mgr->impl());
ASSERT_EQ(fg->manager(), mgr);

ASSERT_EQ(cnode->input(1)->impl(), add_cnode->impl());
mgr->Replace(add_cnode, x);
ASSERT_EQ(cnode->input(1)->impl(), x->impl());

mgr->SetEdge(cnode, 1, y);
ASSERT_EQ(cnode->input(1)->impl(), y->impl());

mgr->AddEdge(cnode, x);
ASSERT_EQ(cnode->size(), 4);
ASSERT_EQ(cnode->input(3)->impl(), x->impl());

auto users = mgr->GetUsers(value_node);
ASSERT_EQ(users.size(), 1);
ASSERT_EQ(users[0].first, cnode);
ASSERT_EQ(users[0].second, 2);
}

/// Feature: MindAPI
/// Description: test value node utils.
/// Expectation: value node utils work as expected.
TEST_F(TestMindApi, test_value_node_utils) {
auto fg = FuncGraph::Create();
auto fg_node = MakeShared<ValueNode>(fg);
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto one = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({fg_node, prim_node, one});

ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode) == nullptr);

auto fg1 = GetValueNode<FuncGraphPtr>(cnode->input(0));
ASSERT_TRUE(fg1 != nullptr);
ASSERT_TRUE(fg1->isa<FuncGraph>());

auto prim1 = GetValueNode<PrimitivePtr>(cnode->input(1));
ASSERT_TRUE(prim1 != nullptr);
ASSERT_TRUE(prim1->isa<Primitive>());

auto imm = GetValueNode<Int64ImmPtr>(cnode->input(2));
ASSERT_TRUE(imm != nullptr);
ASSERT_TRUE(imm->isa<Int64Imm>());
ASSERT_EQ(imm->cast<Int64ImmPtr>()->value(), 1);

auto value = GetValueNode(cnode->input(2));
ASSERT_TRUE(value != nullptr);
ASSERT_EQ(GetValue<int64_t>(value), 1);

ASSERT_TRUE(GetValueNode<PrimitivePtr>(cnode->input(0)) == nullptr);
ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode->input(1)) == nullptr);
ASSERT_TRUE(GetValueNode<StringImmPtr>(cnode->input(2)) == nullptr);

// Test NewValueNode.
auto int_node = NewValueNode(1);
auto bool_node = NewValueNode(true);
auto float_node = NewValueNode(1.23f);
auto str_node = NewValueNode("hello");

ASSERT_TRUE(int_node->value()->isa<Int64Imm>());
ASSERT_EQ(int_node->value()->cast<Int64ImmPtr>()->value(), 1);
ASSERT_TRUE(bool_node->value()->isa<BoolImm>());
ASSERT_TRUE(bool_node->value()->cast<BoolImmPtr>()->value());
ASSERT_TRUE(float_node->value()->isa<FP32Imm>());
ASSERT_TRUE(std::abs(float_node->value()->cast<FP32ImmPtr>()->value() - 1.23f) < 0.0000001f);
ASSERT_TRUE(str_node->value()->isa<StringImm>());
ASSERT_EQ(str_node->value()->cast<StringImmPtr>()->value(), "hello");
}

/// Feature: MindAPI
/// Description: test SharedPtr.
/// Expectation: SharedPtr work as expected.
TEST_F(TestMindApi, test_object_ptr) {
auto fg = FuncGraph::Create();
auto fg_node = MakeShared<ValueNode>(fg);
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto one = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({fg_node, prim_node, one});

ASSERT_TRUE(fg != nullptr);
ASSERT_FALSE(!fg);
ASSERT_TRUE(fg ? true : false);
ASSERT_TRUE((*cnode).input(0) == fg_node);
ASSERT_TRUE(cnode->input(0) == fg_node);
ASSERT_TRUE(cnode.get()->input(0) == fg_node);

ASSERT_EQ(cnode->input(0), fg_node);
ASSERT_EQ(cnode->input(1), prim_node);
ASSERT_EQ(cnode->input(2), one);
ASSERT_TRUE(cnode->input(0) != fg);

AnfNodePtr p = fg_node;
ASSERT_TRUE(p == fg_node);
ASSERT_TRUE(p->isa<ValueNode>());
ASSERT_TRUE(p->cast<ValueNodePtr>() != nullptr);
ASSERT_TRUE(p->cast<ValueNodePtr>() == fg_node);

p = cnode;
ASSERT_TRUE(p == cnode);
ASSERT_TRUE(p->isa<CNode>());
ASSERT_TRUE(p->cast<CNodePtr>() != nullptr);
ASSERT_TRUE(p->cast<CNodePtr>() == cnode);
ASSERT_TRUE(p.get() == cnode.get());

ASSERT_TRUE(p != nullptr);
ASSERT_FALSE(p == nullptr);
ASSERT_TRUE(p > nullptr);
ASSERT_FALSE(p < nullptr);
ASSERT_TRUE(p >= nullptr);
ASSERT_FALSE(p <= nullptr);

ASSERT_TRUE(nullptr != p);
ASSERT_FALSE(nullptr == p);
ASSERT_TRUE(nullptr < p);
ASSERT_FALSE(nullptr > p);
ASSERT_TRUE(nullptr <= p);
ASSERT_FALSE(nullptr >= p);

AnfNodePtr q = fg_node;
ASSERT_TRUE(p != q);
ASSERT_TRUE(p > q);
if (p.get()->impl() > q.get()->impl()) {
ASSERT_TRUE(p > q);
ASSERT_TRUE(p >= q);
ASSERT_TRUE(q < p);
ASSERT_TRUE(q <= p);
} else {
ASSERT_TRUE(p < q);
ASSERT_TRUE(p <= q);
ASSERT_TRUE(q > p);
ASSERT_TRUE(q >= p);
}

std::stringstream ss1;
std::stringstream ss2;
ss1 << p;
ss2 << cnode.get()->impl().get();
ASSERT_EQ(ss1.str(), ss2.str());

std::unordered_map<AnfNodePtr, AnfNodePtr> mymap;
mymap.emplace(p, q);
mymap.emplace(q, p);
ASSERT_TRUE(mymap.find(p) != mymap.end());
ASSERT_TRUE(mymap.find(q) != mymap.end());
ASSERT_TRUE(mymap[p] == q);
ASSERT_TRUE(mymap[q] == p);
}

/// Feature: MindAPI
/// Description: test Tensor API.
/// Expectation: Tensor API work as expected.
TEST_F(TestMindApi, test_tensor_api) {
ShapeVector shape{1, 2, 3};
auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape);

ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32);
ASSERT_EQ(tensor->shape(), shape);
ASSERT_EQ(tensor->DataSize(), 6);
ASSERT_EQ(tensor->Size(), 24);

ShapeVector shape2{2, 3};
tensor->set_data_type(kNumberTypeInt32);
tensor->set_shape(shape2);
ASSERT_EQ(tensor->data_type(), kNumberTypeInt32);
ASSERT_EQ(tensor->shape(), shape2);
}

/// Feature: MindAPI
/// Description: test utils API.
/// Expectation: Tensor API work as expected.
TEST_F(TestMindApi, test_api_utils) {
// Test utils::isa, utils::cast.
auto anf_node = NewValueNode("hello");
ASSERT_TRUE(utils::isa<AnfNode>(anf_node));
ASSERT_FALSE(utils::isa<AbstractBase>(anf_node));
ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) != nullptr);
ASSERT_TRUE(utils::cast<AbstractBasePtr>(anf_node) == nullptr);

anf_node = nullptr;
ASSERT_FALSE(utils::isa<AnfNode>(anf_node));
ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) == nullptr);

// Test clone graph.
auto fg = FuncGraph::Create();
auto x = fg->add_parameter();
x->set_name("x");
auto y = fg->add_parameter();
y->set_name("y");
auto add = MakeShared<Primitive>("add");
auto add_node = MakeShared<ValueNode>(add);
auto add_cnode = fg->NewCNode({add_node, x, y});
auto prim = MakeShared<Primitive>("myprim");
auto prim_node = MakeShared<ValueNode>(prim);
auto value_node = MakeShared<ValueNode>(MakeValue(1));
auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
fg->set_output(cnode);

auto cloned_fg = utils::CloneGraph(fg);
ASSERT_TRUE(cloned_fg != nullptr);
ASSERT_EQ(cloned_fg->parameters().size(), 2);
auto new_output = cloned_fg->output();
ASSERT_TRUE(new_output != nullptr);
ASSERT_TRUE(new_output->isa<CNode>());
ASSERT_EQ(new_output->cast<CNodePtr>()->size(), cnode->size());
ASSERT_TRUE(new_output != cnode);
ASSERT_TRUE(new_output->cast<CNodePtr>() != cnode);

// Test get pad mode.
auto pm_lower = MakeValue("pad");
auto pm_upper = MakeValue("PAD");
ASSERT_EQ(utils::GetPadMode(pm_lower), 0);
ASSERT_EQ(utils::GetPadMode(pm_lower, false), 0);
ASSERT_EQ(utils::GetPadMode(pm_upper, true), 0);
}

/// Feature: MindAPI
/// Description: test logging API.
/// Expectation: logging work as expected.
TEST_F(TestMindApi, test_api_logging) {
MS_LOG(DEBUG) << "hello debug";
MS_LOG(INFO) << "hello info";
MS_LOG(WARNING) << "hello warning";
MS_LOG(ERROR) << "hello error";
try {
MS_LOG(EXCEPTION) << "hello exception";
ASSERT_TRUE(false);
} catch (...) {
}
ASSERT_TRUE(true);
}
} // namespace mindspore::api

Loading…
Cancel
Save