Browse Source

!11794 remove useless code of dot

From: @yuan_shen_zhou
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
066ebe516e
12 changed files with 9 additions and 77 deletions
  1. +1
    -2
      mindspore/_extends/parse/resources.py
  2. +2
    -2
      mindspore/ccsrc/frontend/optimizer/graph_transform.cc
  3. +2
    -3
      mindspore/ccsrc/pipeline/jit/resource.cc
  4. +0
    -2
      mindspore/core/abstract/infer_functions.h
  5. +1
    -30
      mindspore/core/abstract/prim_statement.cc
  6. +0
    -1
      mindspore/core/abstract/primitive_infer_map.cc
  7. +0
    -1
      mindspore/core/base/core_ops.h
  8. +1
    -7
      mindspore/ops/_grad/grad_implementations.py
  9. +0
    -1
      mindspore/ops/functional.py
  10. +1
    -6
      tests/ut/cpp/operator/ops_test.cc
  11. +0
    -5
      tests/ut/cpp/optimizer/ad/ad_test.cc
  12. +1
    -17
      tests/ut/cpp/pipeline/static_analysis/prim_test.cc

+ 1
- 2
mindspore/_extends/parse/resources.py View File

@@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -86,7 +86,6 @@ convert_object_map = {
T.floordiv: multitype_ops.floordiv,
T.mod: multitype_ops.mod,
T.pow: multitype_ops.pow_,
T.matmul: F.dot,
T.lshift: NO_IMPLEMENT,
T.rshift: NO_IMPLEMENT,
T.and_: multitype_ops.logical_and,


+ 2
- 2
mindspore/ccsrc/frontend/optimizer/graph_transform.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ bool CNodeHasTupleInput(const CNodePtr &cnode) {
}
if (IsValueNode<Primitive>(inputs[i])) {
// unexpected high order primitvie as cnode input when transform graph
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitve as input" << cnode->DebugString();
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitive as input" << cnode->DebugString();
return false;
}
auto abs = inputs[i]->abstract();


+ 2
- 3
mindspore/ccsrc/pipeline/jit/resource.cc View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -170,7 +170,6 @@ BuiltInTypeMap &GetMethodMap() {
{"__ge__", std::string("ge")}, // C.ge
{"expand_as", std::string("expand_tensor_as")}, // C.expand_as
{"view", std::string("view")}, // C.view
{"__matmul__", prim::kPrimDot}, // P.dot,
{"__len__", prim::kPrimArrayLen}, // P.array_len,
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
@@ -352,7 +351,7 @@ void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) {
return;
}
MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString();
MS_LOG(DEBUG) << "Record pynative tmp primitive:" << prim->ToString();
pynative_short_life_primitives_.insert(prim);
pynative_new_primtives_squence_.push_back(prim->ToString());
}


+ 0
- 2
mindspore/core/abstract/infer_functions.h View File

@@ -27,8 +27,6 @@ namespace mindspore {
namespace abstract {
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,


+ 1
- 30
mindspore/core/abstract/prim_statement.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -34,35 +34,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
return abs_base;
}

AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);

ShapePtr x_shp = input_x->shape();
auto x_shp_value = x_shp->shape();
ShapePtr y_shp = input_y->shape();
auto y_shp_value = y_shp->shape();
// Should be matrix which shape size is 2.
if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
<< x_shp_value.size() << ", " << y_shp_value.size() << " ";
}
if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
}

auto x_element = input_x->element();
MS_EXCEPTION_IF_NULL(x_element);
(void)x_element->Join(input_y->element());
auto param = {x_shp_value[0], y_shp_value[1]};

return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
}

AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim,
const AbstractBasePtrList &args_spec_list) {
// Inputs: condition, true branch, false branch


+ 0
- 1
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -26,7 +26,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements
{prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
{prim::kPrimIs_, {InferImplIs_, true}},


+ 0
- 1
mindspore/core/base/core_ops.h View File

@@ -67,7 +67,6 @@ inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalO
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");

inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");


+ 1
- 7
mindspore/ops/_grad/grad_implementations.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -188,12 +188,6 @@ def bprop_array_to_scalar(x, out, dout):
return (F.scalar_to_array(dout),)


@bprops.register("dot")
def bprop_dot(x, y, out, dout):
"""Backpropagator for primitive `dot`."""
return F.dot(dout, F.transpose(y, (1, 0))), F.dot(F.transpose(x, (1, 0)), dout)


@bprops.register("reshape")
def bprop_reshape(xs, shp, out, dout):
"""Backpropagator for primitive `reshape`."""


+ 0
- 1
mindspore/ops/functional.py View File

@@ -142,7 +142,6 @@ in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike()
distribute = Primitive('distribute')


+ 1
- 6
tests/ut/cpp/operator/ops_test.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -287,11 +287,6 @@ TEST_F(TestOps, TransposeTest) {
ASSERT_EQ(prim->name(), kPrimTranspose->name());
}

TEST_F(TestOps, DotTest) {
auto prim = std::make_shared<Primitive>("dot");
ASSERT_EQ(prim->name(), kPrimDot->name());
}

TEST_F(TestOps, Im2ColTest) {
auto prim = std::make_shared<Primitive>("im2col");
ASSERT_EQ(prim->name(), kPrimIm2Col->name());


+ 0
- 5
tests/ut/cpp/optimizer/ad/ad_test.cc View File

@@ -169,11 +169,6 @@ TEST_F(TestAD, test_prim_array_to_scalar) {
AssertExpect("test_prim_array_to_scalar", dg);
}

TEST_F(TestAD, test_prim_dot) {
FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDot), resourcePtr);
AssertExpect("test_prim_dot", dg);
}

TEST_F(TestAD, test_prim_distribute) {
FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDistribute), resourcePtr);
AssertExpect("test_prim_distribute", dg);


+ 1
- 17
tests/ut/cpp/pipeline/static_analysis/prim_test.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -291,22 +291,6 @@ TEST_F(TestPrim, test_J_2) {
ASSERT_TRUE(res_J_1 != nullptr);
}

TEST_F(TestPrim, test_dot) {
auto dot = std::make_shared<Primitive>("dot");
FuncGraphPtr func_graph = MakeFuncGraph(dot, 2);

auto a1 = UTPrimUtils::ArrayFloat64Of({2, 3});
auto a2 = UTPrimUtils::ArrayFloat64Of({3, 4});
std::vector<int64_t> expectedA = {2, 4};
auto expected = UTPrimUtils::ArrayFloat64Of({2, 4});

AbstractBasePtrList args_spec_list = {a1, a2};

AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());

ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
}

// tail half
TEST_F(TestPrim, test_switch1) {
PrimitivePtr switch_ = std::make_shared<Primitive>("switch");


Loading…
Cancel
Save