Browse Source

[ME] Code static check.

feature/build-system-rewrite
Margaret_wangrui 4 years ago
parent
commit
e6d7eb2821
17 changed files with 86 additions and 93 deletions
  1. +26
    -19
      mindspore/core/abstract/prim_nn.cc
  2. +2
    -2
      mindspore/core/base/complex_storage.h
  3. +3
    -1
      mindspore/core/base/effect_info.h
  4. +10
    -10
      mindspore/core/base/float16.h
  5. +10
    -14
      mindspore/core/ir/anf.h
  6. +7
    -6
      mindspore/core/ir/device_sync.h
  7. +1
    -2
      mindspore/core/ir/dtype/container.cc
  8. +1
    -2
      mindspore/core/ir/dtype/number.cc
  9. +1
    -2
      mindspore/core/ir/dtype/ref.cc
  10. +1
    -2
      mindspore/core/ir/dtype/tensor_type.cc
  11. +1
    -2
      mindspore/core/ir/dtype/type.cc
  12. +1
    -2
      mindspore/core/ir/dtype_extends.cc
  13. +9
    -8
      mindspore/core/ir/func_graph.cc
  14. +7
    -12
      mindspore/core/ir/func_graph.h
  15. +2
    -2
      mindspore/core/ir/func_graph_extends.cc
  16. +3
    -1
      mindspore/core/ir/kernel_info_dev.h
  17. +1
    -6
      mindspore/core/ops/base_operator.h

+ 26
- 19
mindspore/core/abstract/prim_nn.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,7 +32,9 @@ const size_t padding_start_idx = 0;
int64_t GetAndCheckFormat(const ValuePtr &value) {
int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
if (!result ||
(data_format != static_cast<int64_t>(Format::NHWC) && data_format != static_cast<int64_t>(Format::NCHW) &&
data_format != static_cast<int64_t>(Format::NCDHW))) {
MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW";
}
return data_format;
@@ -80,11 +82,12 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
auto pad_mode_ptr = primitive->GetAttr("pad_mode");
if (pad_mode_ptr != nullptr) {
int64_t pad_mode;
const size_t middle = 2;
CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
if (pad_mode == PadMode::VALID) {
if (pad_mode == static_cast<int64_t>(PadMode::VALID)) {
padding = 0;
} else if (pad_mode == PadMode::SAME) {
padding = (window - 1) / 2;
} else if (pad_mode == static_cast<int64_t>(PadMode::SAME)) {
padding = (window - 1) / middle;
}
}
std::set<std::string> available_mode{"max", "avg"};
@@ -95,9 +98,9 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << ".";
}
}
int64_t h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1;
int64_t w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1;
const size_t twice = 2;
int64_t h_out = (((h_input + twice * padding - (window - 1)) - 1) / stride) + 1;
int64_t w_out = (((w_input + twice * padding - (window - 1)) - 1) / stride) + 1;
ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out};
AbstractBasePtr ret = input_tensor->Broaden();
ret->set_shape(std::make_shared<Shape>(shape_out));
@@ -153,7 +156,7 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr
int64_t data_format = GetAndCheckFormat(data_format_ptr);

size_t c_axis = 1;
if (data_format == Format::NHWC) {
if (data_format == static_cast<int64_t>(Format::NHWC)) {
c_axis = 3;
}
for (size_t i = 1; i < args_spec_list.size(); ++i) {
@@ -204,30 +207,34 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
const std::vector<int64_t> &dilation, const int64_t &pad_mode,
const std::vector<int64_t> &padding) {
if (pad_mode == PadMode::VALID) {
const size_t middle = 2;
const size_t second_index = 2;
const size_t third_index = 3;
if (pad_mode == static_cast<int64_t>(PadMode::VALID)) {
output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])));
output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])));
const size_t nhwc = 4;
(void)pad_list->insert(pad_list->begin(), nhwc, 0);
} else if (pad_mode == PadMode::SAME) {
} else if (pad_mode == static_cast<int64_t>(PadMode::SAME)) {
output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
pad_needed_h = std::max((int64_t)0, pad_needed_h);
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2)));
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / middle)));
pad_list->push_back(pad_needed_h - pad_list->at(0));
int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
pad_needed_w = std::max((int64_t)0, pad_needed_w);
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2)));
pad_list->push_back(pad_needed_w - pad_list->at(2));
} else if (pad_mode == PadMode::PAD) {
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / middle)));
pad_list->push_back(pad_needed_w - pad_list->at(middle));
} else if (pad_mode == static_cast<int64_t>(PadMode::PAD)) {
(void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
output_hw->push_back(static_cast<int64_t>(std::floor(
1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) /
1 + (((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0]) - (kernel[0] - 1) * (dilation[0] - 1)) /
stride[0])));
output_hw->push_back(static_cast<int64_t>(std::floor(
1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) /
stride[1])));
output_hw->push_back(static_cast<int64_t>(
std::floor(1 + (((x_w * 1.0) + pad_list->at(second_index) + pad_list->at(third_index) - kernel[1]) -
(kernel[1] - 1) * (dilation[1] - 1)) /
stride[1])));
}
}



+ 2
- 2
mindspore/core/base/complex_storage.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -62,7 +62,7 @@ struct alignas(sizeof(T) * 2) ComplexStorage {

template <typename T>
inline bool operator==(const ComplexStorage<T> &lhs, const ComplexStorage<T> &rhs) {
return lhs.real_ == rhs.real_ && lhs.imag_ == rhs.imag_;
return (lhs.real_ - rhs.real_ == 0) && (lhs.imag_ - rhs.imag_ == 0);
}

template <typename T>


+ 3
- 1
mindspore/core/base/effect_info.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -52,6 +52,8 @@ class EffectInfoHolder {
// Unset effect info.
void UnsetEffectInfo() { effect_info_ = {EffectInfo::kUnknown, false, false}; }

~EffectInfoHolder() {}

protected:
EffectInfo effect_info_;
};


+ 10
- 10
mindspore/core/base/float16.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -98,8 +98,8 @@ class Float16 {
return *this;
}

static float ToFloat32(Float16 f16) {
constexpr Union32 magic = {113 << 23};
static float ToFloat32(const Float16 &f16) {
constexpr Union32 magic = {.u = 113 << 23};
constexpr uint32_t exponent_adjust = ((127 - 15) << 23);
constexpr uint32_t inf_extra_exp_adjust = ((128 - 16) << 23);
constexpr uint32_t zero_extra_exp_adjust = (1 << 23);
@@ -130,9 +130,9 @@ class Float16 {
private:
static uint16_t FromFloat32(float f32) {
constexpr uint32_t magic = {113 << 23};
constexpr Union32 f32infty = {255 << 23};
constexpr Union32 f16max = {(127 + 16) << 23};
constexpr Union32 denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
constexpr Union32 f32infty = {.u = 255 << 23};
constexpr Union32 f16max = {.u = (127 + 16) << 23};
constexpr Union32 denorm_magic = {.u = ((127 - 15) + (23 - 10) + 1) << 23};
constexpr unsigned int exponent_bits = 13;
constexpr unsigned int sign_bit_shift = 16;
constexpr unsigned int sign_mask = 0x80000000u;
@@ -275,11 +275,11 @@ struct numeric_limits<float16> {
// std::numeric_limits<const volatile T>
// https://stackoverflow.com/a/16519653/
template <>
struct numeric_limits<const mindspore::Float16> : numeric_limits<mindspore::Float16> {};
struct numeric_limits<const mindspore::Float16> : private numeric_limits<mindspore::Float16> {};
template <>
struct numeric_limits<volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
struct numeric_limits<volatile mindspore::Float16> : private numeric_limits<mindspore::Float16> {};
template <>
struct numeric_limits<const volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
struct numeric_limits<const volatile mindspore::Float16> : private numeric_limits<mindspore::Float16> {};
} // namespace std

// Implements standard math functions for float16.
@@ -306,6 +306,6 @@ inline float16 pow(const float16 &a, const float16 &b) {

#endif // ENABLE_ARM32 || ENABLE_ARM64

inline float half_to_float(float16 h) { return static_cast<float>(h); }
inline float half_to_float(const float16 &h) { return static_cast<float>(h); }

#endif // MINDSPORE_CORE_BASE_FLOAT16_H_

+ 10
- 14
mindspore/core/ir/anf.h View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -168,7 +168,7 @@ class MS_CORE_API AnfNode : public Base {
/// \brief Obtain the pointer of KernelInfoDevice.
///
/// \return The pointer of KernelInfoDevice.
const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; }
const KernelInfoDevicePtr &kernel_info_ptr() const { return kernel_info_; }

/// \brief Set device kernel program information.
///
@@ -340,7 +340,7 @@ class MS_CORE_API AnfNode : public Base {
/// \brief Check if there is an interpret node.
///
/// \return True if there is an interpret node, otherwise false.
bool interpret() { return interpret_; }
bool interpret() const { return interpret_; }

/// \brief Whether to use interpretation
///
@@ -558,7 +558,7 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
/// \brief Check if is_load_ is set.
///
/// \return True if is_load_ is set, otherwise false.
bool get_load_flag() { return is_load_; }
bool get_load_flag() const { return is_load_; }

/// \brief Get func_graph_as_var of this CNode.
///
@@ -805,7 +805,7 @@ class MS_CORE_API Parameter final : public ANode {
/// \brief Set the default parameter.
///
/// \param[in] param The default parameter.
void set_default_param(ValuePtr param) {
void set_default_param(const ValuePtr &param) {
default_param_ = param;
has_default_ = true;
}
@@ -876,7 +876,7 @@ class MS_CORE_API Parameter final : public ANode {
/// \brief Get groups attr in FracZ format.
///
/// \return Groups attr in FracZ format.
int64_t fracz_group() { return format_attrs_.fracz_group; }
int64_t fracz_group() const { return format_attrs_.fracz_group; }

/// \brief Set input_size attr in FracNZ_RNN or ND_RNN_Bias format.
///
@@ -886,7 +886,7 @@ class MS_CORE_API Parameter final : public ANode {
/// \brief Get input_size attr in FracNZ_RNN or ND_RNN_Bias format.
///
/// \return input_size attr in FracNZ_RNN or ND_RNN_Bias format.
int64_t input_size() { return format_attrs_.input_size; }
int64_t input_size() const { return format_attrs_.input_size; }

/// \brief Set hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
///
@@ -896,7 +896,7 @@ class MS_CORE_API Parameter final : public ANode {
/// \brief Get hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
///
/// \return hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
int64_t hidden_size() { return format_attrs_.hidden_size; }
int64_t hidden_size() const { return format_attrs_.hidden_size; }

private:
struct FormatAttr {
@@ -946,11 +946,7 @@ class MS_CORE_API Value : public Base {
/// \brief Get the abstract value of Value.
///
/// \return Abstract value of Value.
virtual abstract::AbstractBasePtr ToAbstract() {
MS_LOG(EXCEPTION) << "ToAbstract error";
abstract::AbstractBasePtr result;
return result;
}
virtual abstract::AbstractBasePtr ToAbstract() { MS_LOG(EXCEPTION) << "ToAbstract error"; }

/// \brief Check whether the input is the current Value object.
///
@@ -994,7 +990,7 @@ class MS_CORE_API ValueNode final : public ANode {
~ValueNode() override = default;
MS_DECLARE_PARENT(ValueNode, ANode);

void set_func_graph(const FuncGraphPtr &func_graph) override {
void set_func_graph(const FuncGraphPtr &) override {
MS_EXCEPTION(ValueError) << "ValueNode should not set its func_graph.";
}



+ 7
- 6
mindspore/core/ir/device_sync.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,16 +32,15 @@ class DeviceSync {
public:
// Used to sync data between different device addresses, only need the data size and data ptr. The CPU device doesn't
// need use the interfaces, so need the default implementation.
virtual bool SyncDeviceToHost(size_t size, void *host_ptr) const { return true; }
virtual bool SyncHostToDevice(size_t size, const void *host_ptr) const { return true; }
virtual bool SyncDeviceToHost(size_t, void *) const { return true; }
virtual bool SyncHostToDevice(size_t, const void *) const { return true; }

// Used to sync data between host tensor and device address, additional need the data shape and data type.
virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const = 0;
virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
const std::string &format = "DefaultFormat") const = 0;
virtual bool SyncDeviceToDevice(const DeviceSync *src_device_addr) const { return true; }
virtual bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
const std::string &format) const {
virtual bool SyncDeviceToDevice(const DeviceSync *) const { return true; }
virtual bool AsyncDeviceToDevice(const ShapeVector &, size_t, TypeId type, const void *, const std::string &) const {
return true;
}

@@ -66,6 +65,8 @@ class DeviceSync {
void DecreaseRefCount() { ref_count_--; }
void ResetRefCount() { ref_count_ = original_ref_count_; }

virtual ~DeviceSync() {}

protected:
mutable size_t original_ref_count_{1};
// It will be decreased in the running, and reset by original_ref_count_ when it is zero.


+ 1
- 2
mindspore/core/ir/dtype/container.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
*/

#include "ir/dtype/container.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"


+ 1
- 2
mindspore/core/ir/dtype/number.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
*/

#include "ir/dtype/number.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"


+ 1
- 2
mindspore/core/ir/dtype/ref.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
*/

#include "ir/dtype/ref.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"


+ 1
- 2
mindspore/core/ir/dtype/tensor_type.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
*/

#include "ir/dtype/tensor_type.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"


+ 1
- 2
mindspore/core/ir/dtype/type.cc View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@

#include <algorithm>
#include <cstdlib>
#include <string>
#include <climits>

#include "ir/dtype/number.h"


+ 1
- 2
mindspore/core/ir/dtype_extends.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
*/

#include "ir/dtype.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"


+ 9
- 8
mindspore/core/ir/func_graph.cc View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -578,14 +578,14 @@ AnfNodePtr FuncGraph::GetVariableArgParameter() {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
}
return parameters_[parameters_.size() - hyper_param_count_ - min_param_num];
return parameters_[(parameters_.size() - hyper_param_count_) - min_param_num];
}

if (parameters_.size() < hyper_param_count_ + 1) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
}
return parameters_[parameters_.size() - hyper_param_count_ - 1];
return parameters_[(parameters_.size() - hyper_param_count_) - 1];
}

std::string FuncGraph::GetVariableArgName() {
@@ -600,7 +600,8 @@ std::string FuncGraph::GetVariableArgName() {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
}
const auto &parameter = parameters_[parameters_.size() - hyper_param_count_ - min_param_num]->cast<ParameterPtr>();
const auto &parameter =
parameters_[(parameters_.size() - hyper_param_count_) - min_param_num]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}
@@ -609,7 +610,7 @@ std::string FuncGraph::GetVariableArgName() {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
}
const auto &parameter = parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>();
const auto &parameter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}
@@ -620,7 +621,7 @@ AnfNodePtr FuncGraph::GetVariableKwargParameter() {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
}
return parameters_[parameters_.size() - hyper_param_count_ - 1];
return parameters_[(parameters_.size() - hyper_param_count_) - 1];
}
return nullptr;
}
@@ -631,7 +632,7 @@ std::string FuncGraph::GetVariableKwargName() {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
}
const auto &parameter = parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>();
const auto &parameter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}
@@ -646,7 +647,7 @@ int FuncGraph::GetPositionalArgsCount() const {
if (has_vararg_) {
count--;
}
return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
return (count - kwonlyargs_count_) - SizeToInt(hyper_param_count_);
}

AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {


+ 7
- 12
mindspore/core/ir/func_graph.h View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -95,13 +95,6 @@ const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
const char kFuncGraphFlagReAutoMonad[] = "ReAutoMonad";
const char kFuncGraphFlagRecursive[] = "Recursive";

namespace abstract {
class AbstractKeywordArg;
using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
class AbstractFunction;
using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>;
} // namespace abstract

class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
public:
using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
@@ -348,14 +341,16 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi

bool stub() const { return stub_; }
void set_stub(bool stub) { stub_ = stub; }
static void set_drawer(Drawer drawer) { drawer_ = drawer; }
static void set_drawer(const Drawer &drawer) { drawer_ = drawer; }
std::shared_ptr<bool> switch_input() const { return switch_input_; }
void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; }
void set_switch_input(const std::shared_ptr<bool> &switch_input) { switch_input_ = switch_input; }
std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
void set_switch_layer_input(const std::shared_ptr<bool> &switch_layer_input) {
switch_layer_input_ = switch_layer_input;
}
bool ContainMultiTarget();
bool IsMultiTarget() const { return exist_multi_target_; }
int64_t stage() { return stage_; }
int64_t stage() const { return stage_; }
void set_stage(int64_t stage) { stage_ = stage; }

bool dropped() const { return dropped_; }


+ 2
- 2
mindspore/core/ir/func_graph_extends.cc View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -252,7 +252,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
}
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
size_t kwarg_count = kwarg_list.size();
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - hyper_param_count_);
int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
int variable_args_count = pos_args_input_count - pos_args_count;
std::vector<AnfNodePtr> specialized_parameter_list;


+ 3
- 1
mindspore/core/ir/kernel_info_dev.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -85,6 +85,8 @@ class KernelInfoDevice {

RuntimeCacheScope runtime_cache() { return RuntimeCacheScope(runtime_cache_, mu_); }

virtual ~KernelInfoDevice() {}

private:
RuntimeCache runtime_cache_;
std::mutex mu_;


+ 1
- 6
mindspore/core/ops/base_operator.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,11 +31,6 @@ class AbstractBase;
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
} // namespace abstract

namespace mindspore {
class Primitive;
using PrimitivePtr = std::shared_ptr<Primitive>;
} // namespace mindspore

namespace mindspore {
namespace ops {
class BaseOperator : public api::Primitive {


Loading…
Cancel
Save