Browse Source

add ops infer

pull/14250/head
simson 4 years ago
parent
commit
15bda34dcb
7 changed files with 209 additions and 1 deletions
  1. +1
    -1
      mindspore/core/ops/op_utils.h
  2. +48
    -0
      mindspore/core/ops/scalar_summary.cc
  3. +43
    -0
      mindspore/core/ops/scalar_summary.h
  4. +48
    -0
      mindspore/core/ops/tensor_summary.cc
  5. +43
    -0
      mindspore/core/ops/tensor_summary.h
  6. +23
    -0
      mindspore/core/utils/check_convert_utils.cc
  7. +3
    -0
      mindspore/core/utils/check_convert_utils.h

+ 1
- 1
mindspore/core/ops/op_utils.h View File

@@ -229,7 +229,7 @@ constexpr auto kZoneoutHidden = "zoneout_hidden";
constexpr auto kSpliceContext = "context";
constexpr auto kSpliceForwardIndexes = "forward_indexes";
constexpr auto kSpliceOutputDims = "output_dim";
constexpr auto kSideEffectIO = "side_effect_io";
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};



+ 48
- 0
mindspore/core/ops/scalar_summary.cc View File

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ops/scalar_summary.h"
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {

void ScalarSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); }

bool ScalarSummary::get_side_effect_io() const {
auto value_ptr = GetAttr(kSideEffectIO);
return GetValue<bool>(value_ptr);
}

void ScalarSummary::Init() { this->set_side_effect_io(); }

AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer);
REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary);
} // namespace ops
} // namespace mindspore

+ 43
- 0
mindspore/core/ops/scalar_summary.h View File

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_OPS_SCALAR_SUMMARY_H_
#define MINDSPORE_CORE_OPS_SCALAR_SUMMARY_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameScalarSummary = "ScalarSummary";
class ScalarSummary : public PrimitiveC {
public:
ScalarSummary() : PrimitiveC(kNameScalarSummary) {}
~ScalarSummary() = default;
MS_DECLARE_PARENT(ScalarSummary, PrimitiveC);
void Init();
void set_side_effect_io();
bool get_side_effect_io() const;
};
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_SCALAR_SUMMARY_H_

+ 48
- 0
mindspore/core/ops/tensor_summary.cc View File

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ops/tensor_summary.h"
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {

void TensorSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); }

bool TensorSummary::get_side_effect_io() const {
auto value_ptr = GetAttr(kSideEffectIO);
return GetValue<bool>(value_ptr);
}

void TensorSummary::Init() { this->set_side_effect_io(); }

AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
}
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer);
REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary);
} // namespace ops
} // namespace mindspore

+ 43
- 0
mindspore/core/ops/tensor_summary.h View File

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_OPS_TENSOR_SUMMARY_H_
#define MINDSPORE_CORE_OPS_TENSOR_SUMMARY_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameTensorSummary = "TensorSummary";
class TensorSummary : public PrimitiveC {
public:
TensorSummary() : PrimitiveC(kNameTensorSummary) {}
~TensorSummary() = default;
MS_DECLARE_PARENT(TensorSummary, PrimitiveC);
void Init();
void set_side_effect_io();
bool get_side_effect_io() const;
};
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_TENSOR_SUMMARY_H_

+ 23
- 0
mindspore/core/utils/check_convert_utils.cc View File

@@ -27,6 +27,7 @@
#include "ir/dtype/type.h"
#include "ir/dtype/tensor_type.h"
#include "ir/dtype.h"
#include "utils/ms_context.h"

namespace mindspore {
static std::map<std::string, int64_t> DataFormatToEnumMap = {
@@ -563,4 +564,26 @@ bool CheckAndConvertUtils::CheckIrAttrtoOpAttr(const std::string &op_type, const
MS_LOG(DEBUG) << "convert ir attr to op attr, name: " << op_type << ", attr: " << attr_name;
return true;
}

void CheckAndConvertUtils::CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
const std::string &class_name) {
MS_EXCEPTION_IF_NULL(name);
MS_EXCEPTION_IF_NULL(value);
CheckMode(class_name);
CheckTypeValid("name", name->BuildType(), {kString}, class_name);
auto s = GetValue<std::string>(name->BuildValue());
if (s.empty()) {
MS_EXCEPTION(ValueError) << "For 'name' the value should by valid string in " << class_name
<< ", but got an empty string.";
}
CheckTypeValid("value", value->BuildType(), {kTensorType}, class_name);
}

void CheckAndConvertUtils::CheckMode(const std::string &class_name) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
MS_EXCEPTION(NotSupportError) << class_name << "operator does not support PyNative mode.";
}
}
} // namespace mindspore

+ 3
- 0
mindspore/core/utils/check_convert_utils.h View File

@@ -308,6 +308,9 @@ class CheckAndConvertUtils {
static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);
static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
const std::string &class_name);
static void CheckMode(const std::string &class_name);

private:
static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);


Loading…
Cancel
Save