| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -32,7 +32,9 @@ const size_t padding_start_idx = 0; | |||
| int64_t GetAndCheckFormat(const ValuePtr &value) { | |||
| int64_t data_format; | |||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format); | |||
| if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) { | |||
| if (!result || | |||
| (data_format != static_cast<int64_t>(Format::NHWC) && data_format != static_cast<int64_t>(Format::NCHW) && | |||
| data_format != static_cast<int64_t>(Format::NCDHW))) { | |||
| MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW"; | |||
| } | |||
| return data_format; | |||
| @@ -80,11 +82,12 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| auto pad_mode_ptr = primitive->GetAttr("pad_mode"); | |||
| if (pad_mode_ptr != nullptr) { | |||
| int64_t pad_mode; | |||
| const size_t middle = 2; | |||
| CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true); | |||
| if (pad_mode == PadMode::VALID) { | |||
| if (pad_mode == static_cast<int64_t>(PadMode::VALID)) { | |||
| padding = 0; | |||
| } else if (pad_mode == PadMode::SAME) { | |||
| padding = (window - 1) / 2; | |||
| } else if (pad_mode == static_cast<int64_t>(PadMode::SAME)) { | |||
| padding = (window - 1) / middle; | |||
| } | |||
| } | |||
| std::set<std::string> available_mode{"max", "avg"}; | |||
| @@ -95,9 +98,9 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << "."; | |||
| } | |||
| } | |||
| int64_t h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1; | |||
| int64_t w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1; | |||
| const size_t twice = 2; | |||
| int64_t h_out = (((h_input + twice * padding - (window - 1)) - 1) / stride) + 1; | |||
| int64_t w_out = (((w_input + twice * padding - (window - 1)) - 1) / stride) + 1; | |||
| ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out}; | |||
| AbstractBasePtr ret = input_tensor->Broaden(); | |||
| ret->set_shape(std::make_shared<Shape>(shape_out)); | |||
| @@ -153,7 +156,7 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr | |||
| int64_t data_format = GetAndCheckFormat(data_format_ptr); | |||
| size_t c_axis = 1; | |||
| if (data_format == Format::NHWC) { | |||
| if (data_format == static_cast<int64_t>(Format::NHWC)) { | |||
| c_axis = 3; | |||
| } | |||
| for (size_t i = 1; i < args_spec_list.size(); ++i) { | |||
| @@ -204,30 +207,34 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa | |||
| const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride, | |||
| const std::vector<int64_t> &dilation, const int64_t &pad_mode, | |||
| const std::vector<int64_t> &padding) { | |||
| if (pad_mode == PadMode::VALID) { | |||
| const size_t middle = 2; | |||
| const size_t second_index = 2; | |||
| const size_t third_index = 3; | |||
| if (pad_mode == static_cast<int64_t>(PadMode::VALID)) { | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]))); | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]))); | |||
| const size_t nhwc = 4; | |||
| (void)pad_list->insert(pad_list->begin(), nhwc, 0); | |||
| } else if (pad_mode == PadMode::SAME) { | |||
| } else if (pad_mode == static_cast<int64_t>(PadMode::SAME)) { | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0]))); | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1]))); | |||
| int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h; | |||
| pad_needed_h = std::max((int64_t)0, pad_needed_h); | |||
| pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2))); | |||
| pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / middle))); | |||
| pad_list->push_back(pad_needed_h - pad_list->at(0)); | |||
| int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w; | |||
| pad_needed_w = std::max((int64_t)0, pad_needed_w); | |||
| pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2))); | |||
| pad_list->push_back(pad_needed_w - pad_list->at(2)); | |||
| } else if (pad_mode == PadMode::PAD) { | |||
| pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / middle))); | |||
| pad_list->push_back(pad_needed_w - pad_list->at(middle)); | |||
| } else if (pad_mode == static_cast<int64_t>(PadMode::PAD)) { | |||
| (void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end()); | |||
| output_hw->push_back(static_cast<int64_t>(std::floor( | |||
| 1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / | |||
| 1 + (((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0]) - (kernel[0] - 1) * (dilation[0] - 1)) / | |||
| stride[0]))); | |||
| output_hw->push_back(static_cast<int64_t>(std::floor( | |||
| 1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / | |||
| stride[1]))); | |||
| output_hw->push_back(static_cast<int64_t>( | |||
| std::floor(1 + (((x_w * 1.0) + pad_list->at(second_index) + pad_list->at(third_index) - kernel[1]) - | |||
| (kernel[1] - 1) * (dilation[1] - 1)) / | |||
| stride[1]))); | |||
| } | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2021-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -62,7 +62,7 @@ struct alignas(sizeof(T) * 2) ComplexStorage { | |||
| template <typename T> | |||
| inline bool operator==(const ComplexStorage<T> &lhs, const ComplexStorage<T> &rhs) { | |||
| return lhs.real_ == rhs.real_ && lhs.imag_ == rhs.imag_; | |||
| return (lhs.real_ - rhs.real_ == 0) && (lhs.imag_ - rhs.imag_ == 0); | |||
| } | |||
| template <typename T> | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -52,6 +52,8 @@ class EffectInfoHolder { | |||
| // Unset effect info. | |||
| void UnsetEffectInfo() { effect_info_ = {EffectInfo::kUnknown, false, false}; } | |||
| ~EffectInfoHolder() {} | |||
| protected: | |||
| EffectInfo effect_info_; | |||
| }; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -98,8 +98,8 @@ class Float16 { | |||
| return *this; | |||
| } | |||
| static float ToFloat32(Float16 f16) { | |||
| constexpr Union32 magic = {113 << 23}; | |||
| static float ToFloat32(const Float16 &f16) { | |||
| constexpr Union32 magic = {.u = 113 << 23}; | |||
| constexpr uint32_t exponent_adjust = ((127 - 15) << 23); | |||
| constexpr uint32_t inf_extra_exp_adjust = ((128 - 16) << 23); | |||
| constexpr uint32_t zero_extra_exp_adjust = (1 << 23); | |||
| @@ -130,9 +130,9 @@ class Float16 { | |||
| private: | |||
| static uint16_t FromFloat32(float f32) { | |||
| constexpr uint32_t magic = {113 << 23}; | |||
| constexpr Union32 f32infty = {255 << 23}; | |||
| constexpr Union32 f16max = {(127 + 16) << 23}; | |||
| constexpr Union32 denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; | |||
| constexpr Union32 f32infty = {.u = 255 << 23}; | |||
| constexpr Union32 f16max = {.u = (127 + 16) << 23}; | |||
| constexpr Union32 denorm_magic = {.u = ((127 - 15) + (23 - 10) + 1) << 23}; | |||
| constexpr unsigned int exponent_bits = 13; | |||
| constexpr unsigned int sign_bit_shift = 16; | |||
| constexpr unsigned int sign_mask = 0x80000000u; | |||
| @@ -275,11 +275,11 @@ struct numeric_limits<float16> { | |||
| // std::numeric_limits<const volatile T> | |||
| // https://stackoverflow.com/a/16519653/ | |||
| template <> | |||
| struct numeric_limits<const mindspore::Float16> : numeric_limits<mindspore::Float16> {}; | |||
| struct numeric_limits<const mindspore::Float16> : private numeric_limits<mindspore::Float16> {}; | |||
| template <> | |||
| struct numeric_limits<volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {}; | |||
| struct numeric_limits<volatile mindspore::Float16> : private numeric_limits<mindspore::Float16> {}; | |||
| template <> | |||
| struct numeric_limits<const volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {}; | |||
| struct numeric_limits<const volatile mindspore::Float16> : private numeric_limits<mindspore::Float16> {}; | |||
| } // namespace std | |||
| // Implements standard math functions for float16. | |||
| @@ -306,6 +306,6 @@ inline float16 pow(const float16 &a, const float16 &b) { | |||
| #endif // ENABLE_ARM32 || ENABLE_ARM64 | |||
| inline float half_to_float(float16 h) { return static_cast<float>(h); } | |||
| inline float half_to_float(const float16 &h) { return static_cast<float>(h); } | |||
| #endif // MINDSPORE_CORE_BASE_FLOAT16_H_ | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -168,7 +168,7 @@ class MS_CORE_API AnfNode : public Base { | |||
| /// \brief Obtain the pointer of KernelInfoDevice. | |||
| /// | |||
| /// \return The pointer of KernelInfoDevice. | |||
| const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; } | |||
| const KernelInfoDevicePtr &kernel_info_ptr() const { return kernel_info_; } | |||
| /// \brief Set device kernel program information. | |||
| /// | |||
| @@ -340,7 +340,7 @@ class MS_CORE_API AnfNode : public Base { | |||
| /// \brief Check if there is an interpret node. | |||
| /// | |||
| /// \return True if there is an interpret node, otherwise false. | |||
| bool interpret() { return interpret_; } | |||
| bool interpret() const { return interpret_; } | |||
| /// \brief Whether to use interpretation | |||
| /// | |||
| @@ -558,7 +558,7 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder { | |||
| /// \brief Check if is_load_ is set. | |||
| /// | |||
| /// \return True if is_load_ is set, otherwise false. | |||
| bool get_load_flag() { return is_load_; } | |||
| bool get_load_flag() const { return is_load_; } | |||
| /// \brief Get func_graph_as_var of this CNode. | |||
| /// | |||
| @@ -805,7 +805,7 @@ class MS_CORE_API Parameter final : public ANode { | |||
| /// \brief Set the default parameter. | |||
| /// | |||
| /// \param[in] param The default parameter. | |||
| void set_default_param(ValuePtr param) { | |||
| void set_default_param(const ValuePtr ¶m) { | |||
| default_param_ = param; | |||
| has_default_ = true; | |||
| } | |||
| @@ -876,7 +876,7 @@ class MS_CORE_API Parameter final : public ANode { | |||
| /// \brief Get groups attr in FracZ format. | |||
| /// | |||
| /// \return Groups attr in FracZ format. | |||
| int64_t fracz_group() { return format_attrs_.fracz_group; } | |||
| int64_t fracz_group() const { return format_attrs_.fracz_group; } | |||
| /// \brief Set input_size attr in FracNZ_RNN or ND_RNN_Bias format. | |||
| /// | |||
| @@ -886,7 +886,7 @@ class MS_CORE_API Parameter final : public ANode { | |||
| /// \brief Get input_size attr in FracNZ_RNN or ND_RNN_Bias format. | |||
| /// | |||
| /// \return input_size attr in FracNZ_RNN or ND_RNN_Bias format. | |||
| int64_t input_size() { return format_attrs_.input_size; } | |||
| int64_t input_size() const { return format_attrs_.input_size; } | |||
| /// \brief Set hidden_size attr in FracNZ_RNN or ND_RNN_Bias format. | |||
| /// | |||
| @@ -896,7 +896,7 @@ class MS_CORE_API Parameter final : public ANode { | |||
| /// \brief Get hidden_size attr in FracNZ_RNN or ND_RNN_Bias format. | |||
| /// | |||
| /// \return hidden_size attr in FracNZ_RNN or ND_RNN_Bias format. | |||
| int64_t hidden_size() { return format_attrs_.hidden_size; } | |||
| int64_t hidden_size() const { return format_attrs_.hidden_size; } | |||
| private: | |||
| struct FormatAttr { | |||
| @@ -946,11 +946,7 @@ class MS_CORE_API Value : public Base { | |||
| /// \brief Get the abstract value of Value. | |||
| /// | |||
| /// \return Abstract value of Value. | |||
| virtual abstract::AbstractBasePtr ToAbstract() { | |||
| MS_LOG(EXCEPTION) << "ToAbstract error"; | |||
| abstract::AbstractBasePtr result; | |||
| return result; | |||
| } | |||
| virtual abstract::AbstractBasePtr ToAbstract() { MS_LOG(EXCEPTION) << "ToAbstract error"; } | |||
| /// \brief Check whether the input is the current Value object. | |||
| /// | |||
| @@ -994,7 +990,7 @@ class MS_CORE_API ValueNode final : public ANode { | |||
| ~ValueNode() override = default; | |||
| MS_DECLARE_PARENT(ValueNode, ANode); | |||
| void set_func_graph(const FuncGraphPtr &func_graph) override { | |||
| void set_func_graph(const FuncGraphPtr &) override { | |||
| MS_EXCEPTION(ValueError) << "ValueNode should not set its func_graph."; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -32,16 +32,15 @@ class DeviceSync { | |||
| public: | |||
| // Used to sync data between different device addresses, only need the data size and data ptr. The CPU device doesn't | |||
| // need use the interfaces, so need the default implementation. | |||
| virtual bool SyncDeviceToHost(size_t size, void *host_ptr) const { return true; } | |||
| virtual bool SyncHostToDevice(size_t size, const void *host_ptr) const { return true; } | |||
| virtual bool SyncDeviceToHost(size_t, void *) const { return true; } | |||
| virtual bool SyncHostToDevice(size_t, const void *) const { return true; } | |||
| // Used to sync data between host tensor and device address, additional need the data shape and data type. | |||
| virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const = 0; | |||
| virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr, | |||
| const std::string &format = "DefaultFormat") const = 0; | |||
| virtual bool SyncDeviceToDevice(const DeviceSync *src_device_addr) const { return true; } | |||
| virtual bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr, | |||
| const std::string &format) const { | |||
| virtual bool SyncDeviceToDevice(const DeviceSync *) const { return true; } | |||
| virtual bool AsyncDeviceToDevice(const ShapeVector &, size_t, TypeId type, const void *, const std::string &) const { | |||
| return true; | |||
| } | |||
| @@ -66,6 +65,8 @@ class DeviceSync { | |||
| void DecreaseRefCount() { ref_count_--; } | |||
| void ResetRefCount() { ref_count_ = original_ref_count_; } | |||
| virtual ~DeviceSync() {} | |||
| protected: | |||
| mutable size_t original_ref_count_{1}; | |||
| // It will be decreased in the running, and reset by original_ref_count_ when it is zero. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "ir/dtype/container.h" | |||
| #include <string> | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "ir/dtype/number.h" | |||
| #include <string> | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "ir/dtype/ref.h" | |||
| #include <string> | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "ir/dtype/tensor_type.h" | |||
| #include <string> | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -20,7 +20,6 @@ | |||
| #include <algorithm> | |||
| #include <cstdlib> | |||
| #include <string> | |||
| #include <climits> | |||
| #include "ir/dtype/number.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "ir/dtype.h" | |||
| #include <string> | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -578,14 +578,14 @@ AnfNodePtr FuncGraph::GetVariableArgParameter() { | |||
| MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " | |||
| << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; | |||
| } | |||
| return parameters_[parameters_.size() - hyper_param_count_ - min_param_num]; | |||
| return parameters_[(parameters_.size() - hyper_param_count_) - min_param_num]; | |||
| } | |||
| if (parameters_.size() < hyper_param_count_ + 1) { | |||
| MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " | |||
| << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; | |||
| } | |||
| return parameters_[parameters_.size() - hyper_param_count_ - 1]; | |||
| return parameters_[(parameters_.size() - hyper_param_count_) - 1]; | |||
| } | |||
| std::string FuncGraph::GetVariableArgName() { | |||
| @@ -600,7 +600,8 @@ std::string FuncGraph::GetVariableArgName() { | |||
| MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " | |||
| << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; | |||
| } | |||
| const auto ¶meter = parameters_[parameters_.size() - hyper_param_count_ - min_param_num]->cast<ParameterPtr>(); | |||
| const auto ¶meter = | |||
| parameters_[(parameters_.size() - hyper_param_count_) - min_param_num]->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| return parameter->name(); | |||
| } | |||
| @@ -609,7 +610,7 @@ std::string FuncGraph::GetVariableArgName() { | |||
| MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " | |||
| << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; | |||
| } | |||
| const auto ¶meter = parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>(); | |||
| const auto ¶meter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| return parameter->name(); | |||
| } | |||
| @@ -620,7 +621,7 @@ AnfNodePtr FuncGraph::GetVariableKwargParameter() { | |||
| MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " | |||
| << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; | |||
| } | |||
| return parameters_[parameters_.size() - hyper_param_count_ - 1]; | |||
| return parameters_[(parameters_.size() - hyper_param_count_) - 1]; | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -631,7 +632,7 @@ std::string FuncGraph::GetVariableKwargName() { | |||
| MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " | |||
| << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; | |||
| } | |||
| const auto ¶meter = parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>(); | |||
| const auto ¶meter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| return parameter->name(); | |||
| } | |||
| @@ -646,7 +647,7 @@ int FuncGraph::GetPositionalArgsCount() const { | |||
| if (has_vararg_) { | |||
| count--; | |||
| } | |||
| return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); | |||
| return (count - kwonlyargs_count_) - SizeToInt(hyper_param_count_); | |||
| } | |||
| AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -95,13 +95,6 @@ const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry"; | |||
| const char kFuncGraphFlagReAutoMonad[] = "ReAutoMonad"; | |||
| const char kFuncGraphFlagRecursive[] = "Recursive"; | |||
| namespace abstract { | |||
| class AbstractKeywordArg; | |||
| using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>; | |||
| class AbstractFunction; | |||
| using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>; | |||
| } // namespace abstract | |||
| class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, public EffectInfoHolder { | |||
| public: | |||
| using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>; | |||
| @@ -348,14 +341,16 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi | |||
| bool stub() const { return stub_; } | |||
| void set_stub(bool stub) { stub_ = stub; } | |||
| static void set_drawer(Drawer drawer) { drawer_ = drawer; } | |||
| static void set_drawer(const Drawer &drawer) { drawer_ = drawer; } | |||
| std::shared_ptr<bool> switch_input() const { return switch_input_; } | |||
| void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; } | |||
| void set_switch_input(const std::shared_ptr<bool> &switch_input) { switch_input_ = switch_input; } | |||
| std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; } | |||
| void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; } | |||
| void set_switch_layer_input(const std::shared_ptr<bool> &switch_layer_input) { | |||
| switch_layer_input_ = switch_layer_input; | |||
| } | |||
| bool ContainMultiTarget(); | |||
| bool IsMultiTarget() const { return exist_multi_target_; } | |||
| int64_t stage() { return stage_; } | |||
| int64_t stage() const { return stage_; } | |||
| void set_stage(int64_t stage) { stage_ = stage; } | |||
| bool dropped() const { return dropped_; } | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -252,7 +252,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) | |||
| } | |||
| FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>()); | |||
| size_t kwarg_count = kwarg_list.size(); | |||
| int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_); | |||
| int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - hyper_param_count_); | |||
| int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); | |||
| int variable_args_count = pos_args_input_count - pos_args_count; | |||
| std::vector<AnfNodePtr> specialized_parameter_list; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -85,6 +85,8 @@ class KernelInfoDevice { | |||
| RuntimeCacheScope runtime_cache() { return RuntimeCacheScope(runtime_cache_, mu_); } | |||
| virtual ~KernelInfoDevice() {} | |||
| private: | |||
| RuntimeCache runtime_cache_; | |||
| std::mutex mu_; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2021-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -31,11 +31,6 @@ class AbstractBase; | |||
| using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>; | |||
| } // namespace abstract | |||
| namespace mindspore { | |||
| class Primitive; | |||
| using PrimitivePtr = std::shared_ptr<Primitive>; | |||
| } // namespace mindspore | |||
| namespace mindspore { | |||
| namespace ops { | |||
| class BaseOperator : public api::Primitive { | |||