|
|
|
@@ -25,6 +25,39 @@ int Cast::GetDstT() const { return this->primitive_->value.AsCast()->dstT; } |
|
|
|
void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t; } |
|
|
|
void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; } |
|
|
|
|
|
|
|
int Cast::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { |
|
|
|
if (this->primitive_ == nullptr) { |
|
|
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT; |
|
|
|
if (this->primitive_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new primitiveT failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
this->primitive_->value.type = schema::PrimitiveType_Cast; |
|
|
|
} |
|
|
|
if (this->primitive_->value.type != schema::PrimitiveType_Cast) { |
|
|
|
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (this->primitive_->value.value == nullptr) { |
|
|
|
auto attr = new (std::nothrow) schema::CastT(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new primitiveT value failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto srcAnf = reinterpret_cast<mindspore::Number *>(prim.GetAttr("SrcT").get()); |
|
|
|
auto dstAnf = reinterpret_cast<mindspore::Number *>(prim.GetAttr("DstT").get()); |
|
|
|
attr->srcT = srcAnf->number_type(); |
|
|
|
attr->dstT = dstAnf->number_type(); |
|
|
|
this->primitive_->value.value = attr; |
|
|
|
if (this->primitive_->value.value == nullptr) { |
|
|
|
MS_LOG(ERROR) << "primitive value is nullptr"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
#else |
|
|
|
int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { |
|
|
|
MS_ASSERT(nullptr != primitive); |
|
|
|
|