Browse Source

supplement softmax UnPackAttr

tags/v1.0.0
lyvette 5 years ago
parent
commit
a4617f667f
3 changed files with 33 additions and 0 deletions
  1. +2
    -0
      mindspore/lite/src/ops/primitive_c.cc
  2. +30
    -0
      mindspore/lite/src/ops/softmax.cc
  3. +1
    -0
      mindspore/lite/src/ops/softmax.h

+ 2
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -257,6 +257,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
return NewPrimitiveC<Transpose>(prim, inputs);
} else if (op_type == "tuple_getitem") {
return NewPrimitiveC<TupleGetItem>(prim, inputs);
} else if (op_type == "Softmax") {
return NewPrimitiveC<SoftMax>(prim, inputs);
} else {
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
return nullptr;


+ 30
- 0
mindspore/lite/src/ops/softmax.cc View File

@@ -21,6 +21,36 @@ namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int SoftMax::GetAxis() const { return this->primitive_->value.AsSoftMax()->axis; }

int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_SoftMax;
}
if (this->primitive_->value.type != schema::PrimitiveType_SoftMax) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SoftMaxT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
attr->axis = prim_axis;
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}

void SoftMax::SetAxis(int axis) { this->primitive_->value.AsSoftMax()->axis = axis; }

#else


+ 1
- 0
mindspore/lite/src/ops/softmax.h View File

@@ -31,6 +31,7 @@ class SoftMax : public PrimitiveC {
MS_DECLARE_PARENT(SoftMax, PrimitiveC);
SoftMax() = default;
explicit SoftMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAxis(int axis);

#else


Loading…
Cancel
Save