Browse Source

!7465 [MSLITE][Develop] fix smart reply kernel

Merge pull request !7465 from sunsuodong/fix_lite_kernel
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a3d88f27e1
10 changed files with 55 additions and 49 deletions
  1. +2
    -2
      mindspore/lite/src/common/string_util.cc
  2. +3
    -2
      mindspore/lite/src/ops/custom_extract_features.cc
  3. +8
    -6
      mindspore/lite/src/ops/custom_normalize.cc
  4. +6
    -1
      mindspore/lite/src/ops/custom_predict.cc
  5. +1
    -4
      mindspore/lite/src/ops/hashtable_lookup.cc
  6. +13
    -14
      mindspore/lite/src/ops/lsh_projection.cc
  7. +5
    -1
      mindspore/lite/src/ops/skip_gram.cc
  8. +10
    -15
      mindspore/lite/src/runtime/kernel/arm/fp32/lsh_projection.cc
  9. +0
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/lsh_projection.h
  10. +7
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/skip_gram.cc

+ 2
- 2
mindspore/lite/src/common/string_util.cc View File

@@ -116,13 +116,13 @@ static uint64_t k2 = 0x9ae16a3b2f90404fULL;
uint64_t Fetch64Bit(const char *p) {
uint64_t result;
memcpy(&result, p, sizeof(uint64_t));
return __builtin_bswap64(result);
return result;
}

uint32_t Fetch32Bit(const char *p) {
uint32_t result;
memcpy(&result, p, sizeof(uint32_t));
return __builtin_bswap32(result);
return result;
}

uint64_t Rotate64(uint64_t value, int shift) {


+ 3
- 2
mindspore/lite/src/ops/custom_extract_features.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/ops/custom_extract_features.h"

#include "src/common/string_util.h"

namespace mindspore {
@@ -40,9 +41,9 @@ int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector
MS_ASSERT(output0 != nullptr);
MS_ASSERT(output1 != nullptr);

output0->set_data_type(input->data_type());
output0->set_data_type(kNumberTypeInt32);
output0->SetFormat(input->GetFormat());
output1->set_data_type(input->data_type());
output1->set_data_type(kNumberTypeFloat32);
output1->SetFormat(input->GetFormat());

if (input->data_c() == nullptr) {


+ 8
- 6
mindspore/lite/src/ops/custom_normalize.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/ops/custom_normalize.h"

#include "src/common/string_util.h"

namespace mindspore {
@@ -32,21 +33,22 @@ int CustomNormalize::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
#endif
int CustomNormalize::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.at(0);
auto output = outputs_.at(0);
MS_ASSERT(input != nullptr);
MS_ASSERT(output != nullptr);

output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());

if (input->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
int string_num = lite::GetStringCount(input);
auto output = outputs_.at(0);
MS_ASSERT(output != nullptr);

std::vector<int> shape;
int string_num = lite::GetStringCount(input);
shape.push_back(string_num == 0 ? 1 : string_num);

output->set_shape(shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}



+ 6
- 1
mindspore/lite/src/ops/custom_predict.cc View File

@@ -35,7 +35,12 @@ float CustomPredict::GetWeightThreshold() const {
int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateCustomPredict(*fbb);
auto attr = primitive->value_as_CustomPredict();
if (attr == nullptr) {
MS_LOG(ERROR) << "CustomPredict attr is nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateCustomPredict(*fbb, attr->outputNum(), attr->weightThreshold());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_CustomPredict, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;


+ 1
- 4
mindspore/lite/src/ops/hashtable_lookup.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/ops/hashtable_lookup.h"

#include "src/common/string_util.h"

namespace mindspore {
@@ -54,10 +55,6 @@ int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
int string_num = lite::GetStringCount(input);
std::vector<int> output_shape;
output_shape.push_back(string_num == 0 ? 1 : string_num);
output->set_shape(output_shape);
return RET_OK;
}
} // namespace lite


+ 13
- 14
mindspore/lite/src/ops/lsh_projection.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/ops/lsh_projection.h"

#include "nnacl/lsh_projection_parameter.h"

namespace mindspore {
@@ -27,16 +28,17 @@ int LshProjection::GetLshType() const { return this->primitive_->value_as_LshPro
int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateLshProjection(*fbb);
auto attr = primitive->value_as_LshProjection();
if (attr == nullptr) {
MS_LOG(ERROR) << "LshProjection attr is nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateLshProjection(*fbb, attr->type());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LshProjection, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
namespace {
constexpr int kSparseType = 1;
constexpr int kDenseType = 2;
} // namespace
int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) {
MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given.";
@@ -47,29 +49,26 @@ int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor
return RET_ERROR;
}

auto in_hash = inputs_.at(kSingleNum);
auto in_hash = inputs_.at(0);
MS_ASSERT(in_hash->shape().size() == 2);
MS_ASSERT(in_hash->DimensionSize(1) <= 32);
MS_ASSERT(inputs_.at(kDoubleNum)->shape().size() >= 1);
MS_ASSERT(inputs_.at(1)->shape().size() >= 1);

if (inputs_.size() == kMultiNum) {
MS_ASSERT(inputs_.at(kMultiNum)->shape().size() == 1);
MS_ASSERT(inputs_.at(kMultiNum)->DimensionSize(0) == in_value->DimensionSize(0));
MS_ASSERT(inputs_.at(2)->shape().size() == 1);
MS_ASSERT(inputs_.at(2)->DimensionSize(0) == in_value->DimensionSize(0));
}

auto out_tensor = outputs_.front();
out_tensor->set_data_type(kNumberTypeInt32);
out_tensor->SetFormat(schema::Format::Format_NHWC);
if (!GetInferFlag()) {
return RET_OK;
}

std::vector<int> out_shape;
switch (GetLshType()) {
case kSparseType:
case schema::LshProjectionType_SPARSE:
out_shape.push_back(in_hash->DimensionSize(0));
break;
case kDenseType:
case schema::LshProjectionType_DENSE:
out_shape.push_back(in_hash->DimensionSize(0) * in_hash->DimensionSize(1));
break;
default:


+ 5
- 1
mindspore/lite/src/ops/skip_gram.cc View File

@@ -68,7 +68,11 @@ int SkipGram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());

return RET_INFER_INVALID;
if (input->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 10
- 15
mindspore/lite/src/runtime/kernel/arm/fp32/lsh_projection.cc View File

@@ -13,12 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32/lsh_projection.h"

#include "include/errorcode.h"
#include "src/common/string_util.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "src/common/string_util.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -28,12 +28,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LshProjection;

namespace mindspore::kernel {

namespace {
constexpr int kSparseType = 1;
constexpr int kDenseType = 2;
} // namespace

int LshProjectionCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
@@ -91,10 +85,10 @@ int LshProjectionCPUKernel::DoExecute(int task_id) {
}

switch (lsh_param_->lsh_type_) {
case kSparseType:
case schema::LshProjectionType_SPARSE:
LshProjectionSparse(hash, in_data, weight, output, lsh_param_);
break;
case kDenseType:
case schema::LshProjectionType_DENSE:
LshProjectionDense(hash, in_data, weight, output, lsh_param_);
break;
default:
@@ -106,7 +100,7 @@ int LshProjectionCPUKernel::DoExecute(int task_id) {
int LshProjectionCPUKernel::GetSignBit(char *in_data, float *weight, float seed, LshProjectionParameter *para) {
double score = 0.0;
for (int i = 0; i < para->in_item_num_; i++) {
char *key = static_cast<char *>(ctx_->allocator->Malloc(lsh_param_->key_size_));
char *key = static_cast<char *>(context_->allocator->Malloc(lsh_param_->key_size_));
if (key == nullptr) {
MS_LOG(ERROR) << "malloc key failed.";
return RET_ERROR;
@@ -114,13 +108,14 @@ int LshProjectionCPUKernel::GetSignBit(char *in_data, float *weight, float seed,
memcpy(key, &seed, para->seed_size_);
memcpy(key + para->seed_size_, in_data, para->in_item_size_);
in_data += para->in_item_size_;
double hash_sign = static_cast<double>(mindspore::lite::StringHash64(key, para->key_size_));
int64_t hash_i = static_cast<int64_t>(mindspore::lite::StringHash64(key, para->key_size_));
double hash_d = static_cast<double>(hash_i);
if (weight == nullptr) {
score += hash_sign;
score += hash_d;
} else {
score += weight[i] * hash_sign;
score += weight[i] * hash_d;
}
ctx_->allocator->Free(key);
context_->allocator->Free(key);
}
return (score > 0) ? 1 : 0;
}


+ 0
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/lsh_projection.h View File

@@ -21,7 +21,6 @@

#include "nnacl/lsh_projection_parameter.h"
#include "src/lite_kernel.h"
#include "schema/model_generated.h"

namespace mindspore::kernel {
class LshProjectionCPUKernel : public LiteKernel {
@@ -44,7 +43,6 @@ class LshProjectionCPUKernel : public LiteKernel {

private:
LshProjectionParameter *lsh_param_ = nullptr;
const lite::InnerContext *ctx_;
int thread_num_;
int64_t elements_num_;
int64_t count_unit_;


+ 7
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/skip_gram.cc View File

@@ -15,6 +15,7 @@
*/

#include "src/runtime/kernel/arm/fp32/skip_gram.h"

#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
@@ -59,6 +60,11 @@ void ParseSentenceToWords(const StringPack &sentence, std::vector<StringPack> *w
}

int SkipGramCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
skip_gram_parameter_ = reinterpret_cast<SkipGramParameter *>(op_parameter_);
if (skip_gram_parameter_->ngram_size < 1) {
MS_LOG(ERROR) << "Skip Gram Parameter Error, NgramSize should be at least 1, get "
@@ -99,8 +105,7 @@ int SkipGramCPUKernel::Run() {
index--;
}
}

int ret = mindspore::lite::WriteSeperatedStringsToTensor(out_tensors_[0], result);
ret = mindspore::lite::WriteSeperatedStringsToTensor(out_tensors_[0], result);
return ret;
}



Loading…
Cancel
Save