Browse Source

change process of infer value

pull/15431/head
lianliguang 4 years ago
parent
commit
d2c696667b
11 changed files with 49 additions and 46 deletions
  1. +4
    -4
      mindspore/ccsrc/backend/optimizer/common/helper.cc
  2. +13
    -8
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
  3. +1
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
  4. +7
    -7
      mindspore/core/abstract/primitive_infer_map.h
  5. +2
    -3
      mindspore/core/ops/dtype.cc
  6. +1
    -1
      mindspore/core/ops/primitive_c.cc
  7. +1
    -2
      mindspore/core/ops/shape.cc
  8. +10
    -10
      mindspore/core/ops/zeros.cc
  9. +4
    -4
      mindspore/core/utils/tensor_construct_utils.cc
  10. +4
    -4
      mindspore/core/utils/tensor_construct_utils.h
  11. +2
    -2
      tests/ut/cpp/abstract/abstract_test.cc

+ 4
- 4
mindspore/ccsrc/backend/optimizer/common/helper.cc View File

@@ -924,16 +924,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis
auto ret = prim_eval_implement_map.find(prim);
if (ret != prim_eval_implement_map.end()) {
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_);
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list);
return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
} else {
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
auto ret_backend = prim_backend_eval_impl_map.find(prim);
if (ret_backend != prim_backend_eval_impl_map.end()) {
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_);
return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list);
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list);
}
}
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()


+ 13
- 8
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -576,7 +576,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
prim_py->RunCheck(py_args);

prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();

@@ -602,16 +602,21 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
AbstractBasePtr abs_base = nullptr;
ValuePtr value = nullptr;
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
if (need_infer_value) {
if (eval_impl_.infer_value_func_ != nullptr) {
auto value = eval_impl_.infer_value_func_(prim_, args, abs_base);
abs_base->set_value(value);
if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) {
value = eval_impl_.infer_value_impl_(prim_, args);
if (value != nullptr) {
abs_base = value->ToAbstract();
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
}
}
abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
return eval_result;
}


+ 1
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc View File

@@ -359,7 +359,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr

// find prim infer function in the prim function map return a standard evaluator
auto eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl.infer_shape_dtype_impl_ != nullptr) {
if (eval_impl.infer_shape_impl_ != nullptr) {
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
}



+ 7
- 7
mindspore/core/abstract/primitive_infer_map.h View File

@@ -28,13 +28,13 @@

namespace mindspore {
namespace abstract {
using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &);
using InferShapeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
using InferValueImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &);

struct StandardPrimitiveImplReg {
InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive
InferValueEvalImpl infer_value_func_; // infer value of primitive
InferShapeImpl infer_shape_impl_; // infer shape and type for ops
InferValueImpl infer_value_impl_; // infer value for ops
// in_white_list_ is true means this primitive can be executed by vm backend
// else will be optimized by frontend
bool in_white_list_;
@@ -55,8 +55,8 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard

class RegisterStandardPrimitiveEvalHelper {
public:
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl,
const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) {
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl,
const InferValueImpl &infer_value_impl, const bool is_wight_list = true) {
const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}


+ 2
- 3
mindspore/core/ops/dtype.cc View File

@@ -27,8 +27,7 @@

namespace mindspore {
namespace ops {
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
const AbstractBasePtr &infer) {
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name);
@@ -41,7 +40,7 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra

AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto value = DTypeInferValue(primitive, input_args, nullptr);
auto value = DTypeInferValue(primitive, input_args);
MS_EXCEPTION_IF_NULL(value);
auto type = value->cast<TypePtr>();
MS_EXCEPTION_IF_NULL(type);


+ 1
- 1
mindspore/core/ops/primitive_c.cc View File

@@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
if (iter == infer_map.end()) {
MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!";
}
auto infer_function = iter->second.infer_shape_dtype_impl_;
auto infer_function = iter->second.infer_shape_impl_;
return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list);
}



+ 1
- 2
mindspore/core/ops/shape.cc View File

@@ -45,8 +45,7 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
return abs;
}

ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
const AbstractBasePtr &infer) {
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);


+ 10
- 10
mindspore/core/ops/zeros.cc View File

@@ -49,9 +49,17 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
}
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args,
const abstract::AbstractBasePtr &abs) {
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args),
ZerosInferShape(primitive, input_args));
return abs;
}

ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto abs = ZerosInfer(nullptr, prim, input_args);
// check
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape];
auto out_type = abs->BuildType();
@@ -59,14 +67,6 @@ ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas
return TensorConstructUtils::CreateZerosTensor(out_type, out_shape);
}
} // namespace

AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args),
ZerosInferShape(primitive, input_args));
return abs;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false);
} // namespace ops
} // namespace mindspore

+ 4
- 4
mindspore/core/utils/tensor_construct_utils.cc View File

@@ -17,7 +17,7 @@
#include <vector>
#include <memory>
namespace mindspore {
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) {
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
MS_EXCEPTION_IF_NULL(type_ptr);
auto type_id = ExtractTypeId(type_ptr);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
@@ -30,7 +30,7 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr
return tensor;
}

tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) {
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
MS_EXCEPTION_IF_NULL(type_ptr);
auto type_id = ExtractTypeId(type_ptr);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
@@ -43,7 +43,7 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr,
return tensor;
}

tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape,
tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape,
void *data) {
MS_EXCEPTION_IF_NULL(type_ptr);
auto type_id = ExtractTypeId(type_ptr);
@@ -51,7 +51,7 @@ tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, con
return tensor;
}

TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) {
TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type_ptr) {
MS_EXCEPTION_IF_NULL(type_ptr);
TypeId type_id;
if (type_ptr->isa<TensorType>()) {


+ 4
- 4
mindspore/core/utils/tensor_construct_utils.h View File

@@ -28,12 +28,12 @@ void SetTensorData(void *data, T num, size_t data_length) {
}
class TensorConstructUtils {
public:
static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape);
static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape);
static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data);
static tensor::TensorPtr CreateZerosTensor(const TypePtr &type, const std::vector<int64_t> &shape);
static tensor::TensorPtr CreateOnesTensor(const TypePtr &type, const std::vector<int64_t> &shape);
static tensor::TensorPtr CreateTensor(const TypePtr &type, const std::vector<int64_t> &shape, void *data);

private:
static TypeId ExtractTypeId(const TypePtr type);
static TypeId ExtractTypeId(const TypePtr &type);
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_

+ 2
- 2
tests/ut/cpp/abstract/abstract_test.cc View File

@@ -87,9 +87,9 @@ TEST_F(TestAbstract, TestParseDataClass) {
AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y};

auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_);
ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_);

AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list);
AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list);
ASSERT_TRUE(nullptr != new_cls);
}



Loading…
Cancel
Save