Browse Source

fix arithmetic compare, matmul, logicalnot, constant_folding_fusion

tags/v1.1.0
gongdaguo 5 years ago
parent
commit
815b7af9ec
23 changed files with 244 additions and 184 deletions
  1. +33
    -0
      mindspore/lite/src/ops/arithmetic_compare.cc
  2. +41
    -0
      mindspore/lite/src/ops/arithmetic_compare.h
  3. +0
    -10
      mindspore/lite/src/ops/equal.cc
  4. +4
    -5
      mindspore/lite/src/ops/equal.h
  5. +0
    -10
      mindspore/lite/src/ops/greater.cc
  6. +4
    -5
      mindspore/lite/src/ops/greater.h
  7. +0
    -10
      mindspore/lite/src/ops/greater_equal.cc
  8. +4
    -5
      mindspore/lite/src/ops/greater_equal.h
  9. +0
    -10
      mindspore/lite/src/ops/less.cc
  10. +4
    -5
      mindspore/lite/src/ops/less.h
  11. +0
    -10
      mindspore/lite/src/ops/less_equal.cc
  12. +4
    -5
      mindspore/lite/src/ops/less_equal.h
  13. +17
    -3
      mindspore/lite/src/ops/matmul.cc
  14. +0
    -10
      mindspore/lite/src/ops/not_equal.cc
  15. +4
    -5
      mindspore/lite/src/ops/not_equal.h
  16. +2
    -2
      mindspore/lite/src/ops/populate/arithmetic_populate.cc
  17. +64
    -59
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc
  18. +37
    -7
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.h
  19. +9
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
  20. +6
    -4
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h
  21. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc
  22. +10
    -10
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
  23. +0
    -9
      mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc

+ 33
- 0
mindspore/lite/src/ops/arithmetic_compare.cc View File

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/arithmetic_compare.h"

namespace mindspore {
namespace lite {

int ArithmeticCompare::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto res = Arithmetic::InferShape(inputs_, outputs_);
if (res == RET_OK) {
auto output = outputs_.front();
output->set_data_type(TypeId::kNumberTypeBool);
return RET_OK;
} else {
return res;
}
}
} // namespace lite
} // namespace mindspore

+ 41
- 0
mindspore/lite/src/ops/arithmetic_compare.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_

#include <vector>
#include <set>
#include <cmath>

#include "src/ops/arithmetic.h"

namespace mindspore {
namespace lite {
class ArithmeticCompare : public Arithmetic {
public:
ArithmeticCompare() = default;
~ArithmeticCompare() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArithmeticCompare, Arithmetic);
explicit ArithmeticCompare(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_

+ 0
- 10
mindspore/lite/src/ops/equal.cc View File

@@ -35,16 +35,6 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); } PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); }
Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator); Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator);
#endif #endif
int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}


} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 4
- 5
mindspore/lite/src/ops/equal.h View File

@@ -20,21 +20,20 @@
#include <vector> #include <vector>
#include <set> #include <set>
#include <cmath> #include <cmath>
#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class Equal : public Arithmetic {
class Equal : public ArithmeticCompare {
public: public:
Equal() = default; Equal() = default;
~Equal() = default; ~Equal() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Equal, PrimitiveC);
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(Equal, ArithmeticCompare);
explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 0
- 10
mindspore/lite/src/ops/greater.cc View File

@@ -36,16 +36,6 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); } PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); }
Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator); Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator);
#endif #endif
int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}


} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 4
- 5
mindspore/lite/src/ops/greater.h View File

@@ -20,21 +20,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>


#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class Greater : public Arithmetic {
class Greater : public ArithmeticCompare {
public: public:
Greater() = default; Greater() = default;
~Greater() = default; ~Greater() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Greater, Arithmetic);
explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(Greater, ArithmeticCompare);
explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 0
- 10
mindspore/lite/src/ops/greater_equal.cc View File

@@ -38,16 +38,6 @@ PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) {
Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator); Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator);


#endif #endif
int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}


} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 4
- 5
mindspore/lite/src/ops/greater_equal.h View File

@@ -21,21 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>


#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class GreaterEqual : public Arithmetic {
class GreaterEqual : public ArithmeticCompare {
public: public:
GreaterEqual() = default; GreaterEqual() = default;
~GreaterEqual() = default; ~GreaterEqual() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GreaterEqual, Arithmetic);
explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(GreaterEqual, ArithmeticCompare);
explicit GreaterEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 0
- 10
mindspore/lite/src/ops/less.cc View File

@@ -38,16 +38,6 @@ PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC:
Registry LessRegistry(schema::PrimitiveType_Less, LessCreator); Registry LessRegistry(schema::PrimitiveType_Less, LessCreator);


#endif #endif
int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}


} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 4
- 5
mindspore/lite/src/ops/less.h View File

@@ -21,21 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>


#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class Less : public Arithmetic {
class Less : public ArithmeticCompare {
public: public:
Less() = default; Less() = default;
~Less() = default; ~Less() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Less, Arithmetic);
explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(Less, ArithmeticCompare);
explicit Less(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 0
- 10
mindspore/lite/src/ops/less_equal.cc View File

@@ -37,16 +37,6 @@ PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) {
} }
Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator); Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator);
#endif #endif
int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}


} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 4
- 5
mindspore/lite/src/ops/less_equal.h View File

@@ -21,21 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>


#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class LessEqual : public Arithmetic {
class LessEqual : public ArithmeticCompare {
public: public:
LessEqual() = default; LessEqual() = default;
~LessEqual() = default; ~LessEqual() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LessEqual, Arithmetic);
explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(LessEqual, ArithmeticCompare);
explicit LessEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 17
- 3
mindspore/lite/src/ops/matmul.cc View File

@@ -112,9 +112,17 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
input0->set_shape(a_shape); input0->set_shape(a_shape);
} }


if (a_shape.size() < 2 || b_shape.size() < 2) {
MS_LOG(ERROR) << "inputs shape is invalid";
return RET_INPUT_TENSOR_ERROR;
bool del_start = false;
bool del_end = false;
if (a_shape.size() == 1) {
a_shape.insert(a_shape.begin(), 1);
input0->set_shape(a_shape);
del_start = true;
}
if (b_shape.size() == 1) {
b_shape.push_back(1);
input1->set_shape(b_shape);
del_end = true;
} }
for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) { for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) {
if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) { if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) {
@@ -131,6 +139,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
} }
std::vector<int> c_shape(a_shape); std::vector<int> c_shape(a_shape);
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1]; c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
if (del_start) {
c_shape.erase(c_shape.begin());
}
if (del_end) {
c_shape.pop_back();
}
output->set_shape(c_shape); output->set_shape(c_shape);
return RET_OK; return RET_OK;
} }


+ 0
- 10
mindspore/lite/src/ops/not_equal.cc View File

@@ -38,16 +38,6 @@ PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) {
Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator); Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator);


#endif #endif
int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}


} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 4
- 5
mindspore/lite/src/ops/not_equal.h View File

@@ -21,21 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>


#include "src/ops/arithmetic.h"
#include "src/ops/arithmetic_compare.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class NotEqual : public Arithmetic {
class NotEqual : public ArithmeticCompare {
public: public:
NotEqual() = default; NotEqual() = default;
~NotEqual() = default; ~NotEqual() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(NotEqual, Arithmetic);
explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
MS_DECLARE_PARENT(NotEqual, ArithmeticCompare);
explicit NotEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 2
- 2
mindspore/lite/src/ops/populate/arithmetic_populate.cc View File

@@ -58,11 +58,11 @@ Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithme
Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic); Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic);
Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic); Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic);
Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic); Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic);
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic); Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic);
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic); Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic);
Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic); Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic);
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic); Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic);
Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic);
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);


+ 64
- 59
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc View File

@@ -28,77 +28,82 @@ using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_NotEqual; using mindspore::schema::PrimitiveType_NotEqual;


namespace mindspore::kernel { namespace mindspore::kernel {
namespace {
typedef struct {
int primitive_type_;
ArithmeticCompareFp32Func func_;
} TYPE_FUNC_INFO;
} // namespace


ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) {
TYPE_FUNC_INFO type_func_table[] = {
{PrimitiveType_Equal, ElementEqualFp32}, {PrimitiveType_NotEqual, ElementNotEqualFp32},
{PrimitiveType_Less, ElementLessFp32}, {PrimitiveType_LessEqual, ElementLessEqualFp32},
{PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}};
for (size_t i = 0; i < sizeof(type_func_table); i++) {
if (type_func_table[i].primitive_type_ == primitive_type) {
return type_func_table[i].func_;
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
int out_thread_stride) {
if (dim > break_pos_) {
if (data_type_ == kDataTypeInt) {
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
reinterpret_cast<int *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
} }
return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride,
reinterpret_cast<float *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
} }
return nullptr;
}

int ArithmeticCompareCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
int error_code;
if (data_type_ == kDataTypeInt) {
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
}
if (error_code != RET_OK) {
return error_code;
}
} }
return ReSize();
return RET_OK;
} }


int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; }
int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum();


int ArithmeticCompareCPUKernel::DoExecute(int task_id) {
if (in_tensors_.at(0)->shape() != in_tensors_.at(1)->shape()) {
MS_LOG(ERROR) << "Compare op must inputs have the same shape, support broadcast later! ";
return RET_ERROR;
}
int elements_num = in_tensors_.at(0)->ElementsNum();
int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
int offset = task_id * stride;
int count = MSMIN(stride, elements_num - offset);
if (count <= 0) {
return RET_OK;
}
if (func_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int count = MSMIN(stride, element_num - stride * task_id);

if (func_fp32_ == nullptr) {
MS_LOG(ERROR) << "func_fp32_ function is nullptr!";
return RET_ERROR; return RET_ERROR;
} }
// two inputs have the same shape, support broadcast later
auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData());
auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run failed, illegal input! ";
}
return ret;
}


int ArithmeticCompareRun(void *cdata, int task_id) {
auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata);
auto ret = kernel->DoExecute(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]";
int error_code;
if (arithmeticParameter_->broadcasting_) { // need broadcast
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(
reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
} else { // no broadcast, neither is scalar, two same shape
if (data_type_ == kDataTypeFloat) {
error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} else {
error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
}
} }
return ret;
}

int ArithmeticCompareCPUKernel::Run() {
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
if (error_code != RET_OK) {
return RET_ERROR;
} }
return ret;
return RET_OK;
} }


kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,


+ 37
- 7
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.h View File

@@ -18,27 +18,57 @@


#include <vector> #include <vector>
#include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h" #include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h"
#include "nnacl/fp32/arithmetic_compare_fp32.h"


namespace mindspore::kernel { namespace mindspore::kernel {
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size); typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel { class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
public: public:
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive) const mindspore::lite::PrimitiveC *primitive)
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) { : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {
func_ = GetArithmeticCompareFun(parameter->type_);
switch (parameter->type_) {
case PrimitiveType_Equal:
func_fp32_ = ElementEqualFp32;
func_int32_ = ElementEqualInt32;
break;
case PrimitiveType_NotEqual:
func_fp32_ = ElementNotEqualFp32;
func_int32_ = ElementNotEqualInt32;
break;
case PrimitiveType_Less:
func_fp32_ = ElementLessFp32;
func_int32_ = ElementLessInt32;
break;
case PrimitiveType_LessEqual:
func_fp32_ = ElementLessEqualFp32;
func_int32_ = ElementLessEqualInt32;
break;
case PrimitiveType_Greater:
func_fp32_ = ElementGreaterFp32;
func_int32_ = ElementGreaterInt32;
break;
case PrimitiveType_GreaterEqual:
func_fp32_ = ElementGreaterEqualFp32;
func_int32_ = ElementGreaterEqualInt32;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
func_fp32_ = nullptr;
func_int32_ = nullptr;
break;
}
} }
~ArithmeticCompareCPUKernel() override = default; ~ArithmeticCompareCPUKernel() override = default;


int Init() override;
int ReSize() override;
int Run() override;
virtual int DoExecute(int task_id);
int DoArithmetic(int task_id) override;
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override;


private: private:
ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type);
ArithmeticCompareFp32Func func_;
ArithmeticCompareFp32Func func_fp32_ = nullptr;
ArithmeticCompareIntFunc func_int32_ = nullptr;
}; };
int ArithmeticCompareRun(void *cdata, int task_id); int ArithmeticCompareRun(void *cdata, int task_id);
} // namespace mindspore::kernel } // namespace mindspore::kernel


+ 9
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc View File

@@ -175,6 +175,15 @@ int ArithmeticCPUKernel::ReSize() {
break; break;
} }
break; break;
case PrimitiveType_Equal:
case PrimitiveType_Less:
case PrimitiveType_Greater:
case PrimitiveType_NotEqual:
case PrimitiveType_LessEqual:
case PrimitiveType_GreaterEqual:
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
break;
default: default:
break; break;
} }


+ 6
- 4
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h View File

@@ -167,19 +167,21 @@ class ArithmeticCPUKernel : public LiteKernel {
int PreProcess() override; int PreProcess() override;
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int DoArithmetic(int task_id);
virtual int DoArithmetic(int task_id);
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);


private:
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
protected:
int break_pos_ = 0; int break_pos_ = 0;
int outside_ = 0; int outside_ = 0;
int thread_count_ = 1; int thread_count_ = 1;
ArithmeticParameter *arithmeticParameter_ = nullptr; ArithmeticParameter *arithmeticParameter_ = nullptr;
LiteDataType data_type_ = kDataTypeFloat;

private:
ArithmeticRun arithmetic_run_ = nullptr; ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr; ArithmeticOptRun arithmetic_opt_run_ = nullptr;
ArithmeticIntRun arithmetic_run_int_ = nullptr; ArithmeticIntRun arithmetic_run_int_ = nullptr;
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr; ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
LiteDataType data_type_ = kDataTypeFloat;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_

+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc View File

@@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel

+ 10
- 10
mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc View File

@@ -135,17 +135,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (nodeIter == onnx_graph.initializer().end()) { if (nodeIter == onnx_graph.initializer().end()) {
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight;
return RET_ERROR;
}
std::vector<int> weight_shape;
auto size = (*nodeIter).dims_size();
weight_shape.reserve(size);
for (int i = 0; i < size; ++i) {
weight_shape.emplace_back((*nodeIter).dims(i));
MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
} else {
std::vector<int> weight_shape;
auto size = (*nodeIter).dims_size();
weight_shape.reserve(size);
for (int i = 0; i < size; ++i) {
weight_shape.emplace_back((*nodeIter).dims(i));
}
attr->channelOut = weight_shape[0];
attr->channelIn = weight_shape[1] * attr->group;
} }
attr->channelOut = weight_shape[0];
attr->channelIn = weight_shape[1] * attr->group;
} else { } else {
auto nodeIter = auto nodeIter =
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),


+ 0
- 9
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc View File

@@ -231,15 +231,6 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
output_tensors[m]->AddQuantParam(quant_arg); output_tensors[m]->AddQuantParam(quant_arg);
} }
} }
// here, input_tensor's format need to be transposed nhwc according to fmkType,
// but for the time being, we only transpose the tensor with 0/1/2/3D.
// Others should be added in future.
for (auto &input_tensor : input_tensors) {
input_tensor->set_format(schema::Format::Format_NHWC);
if (input_tensor->shape().size() == 4) {
MS_LOG(INFO) << "init input_tensor format to nhwc";
}
}
lite_primitive->InferShape(input_tensors, output_tensors); lite_primitive->InferShape(input_tensors, output_tensors);
auto primitive = lite_primitive.get(); auto primitive = lite_primitive.get();
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);


Loading…
Cancel
Save