|
|
|
@@ -266,8 +266,8 @@ void PrimitiveC::PopulaterInputQuantParam(const Primitive &prim, const std::vect |
|
|
|
if (filterMin != nullptr && filterMax != nullptr) { |
|
|
|
auto filterMinPtr = filterMin->cast<TensorPtr>(); |
|
|
|
auto filterMaxPtr = filterMax->cast<TensorPtr>(); |
|
|
|
float *minBuf = static_cast<float *>(filterMinPtr->data_c()); |
|
|
|
float *maxBuf = static_cast<float *>(filterMaxPtr->data_c()); |
|
|
|
auto *minBuf = static_cast<float *>(filterMinPtr->data_c()); |
|
|
|
auto *maxBuf = static_cast<float *>(filterMaxPtr->data_c()); |
|
|
|
quantParam.min = FLT_MAX; |
|
|
|
quantParam.max = FLT_MIN; |
|
|
|
for (int i = 0; i < filterMinPtr->ElementsNum(); ++i) { |
|
|
|
@@ -296,8 +296,8 @@ void PrimitiveC::PopulaterOutputQuantParam(const Primitive &prim, bool narrowRan |
|
|
|
if (outputMin != nullptr && outputMax != nullptr) { |
|
|
|
auto outputMinPtr = outputMin->cast<TensorPtr>(); |
|
|
|
auto outputMaxPtr = outputMax->cast<TensorPtr>(); |
|
|
|
float *minBuf = static_cast<float *>(outputMinPtr->data_c()); |
|
|
|
float *maxBuf = static_cast<float *>(outputMaxPtr->data_c()); |
|
|
|
auto *minBuf = static_cast<float *>(outputMinPtr->data_c()); |
|
|
|
auto *maxBuf = static_cast<float *>(outputMaxPtr->data_c()); |
|
|
|
quantParam.min = *minBuf; |
|
|
|
quantParam.max = *maxBuf; |
|
|
|
auto ret = quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, |
|
|
|
@@ -317,14 +317,14 @@ void PrimitiveC::PopulaterOutputQuantParam(const Primitive &prim, bool narrowRan |
|
|
|
|
|
|
|
void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { |
|
|
|
auto narrow_range = prim.GetAttr("narrow_range"); |
|
|
|
bool narrowRangeQuantParam = narrow_range != nullptr ? GetValue<bool>(narrow_range) : false; |
|
|
|
bool narrowRangeQuantParam = narrow_range != nullptr && GetValue<bool>(narrow_range); |
|
|
|
auto num_bits = prim.GetAttr("num_bits"); |
|
|
|
int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int64_t>(num_bits) : 8; |
|
|
|
PopulaterInputQuantParam(prim, inputs, narrowRangeQuantParam, numbitsRangeQuantParam); |
|
|
|
PopulaterOutputQuantParam(prim, narrowRangeQuantParam, numbitsRangeQuantParam); |
|
|
|
} |
|
|
|
|
|
|
|
void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<int> *data) { |
|
|
|
void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<int> *data) { |
|
|
|
if (inputNode->isa<ValueNode>()) { |
|
|
|
auto valNode = inputNode->cast<ValueNodePtr>(); |
|
|
|
MS_ASSERT(valNode != nullptr); |
|
|
|
@@ -394,12 +394,12 @@ void PrimitiveC::ClearInputOutputQuantParam() { |
|
|
|
output_quant_param_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
void PrimitiveC::AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) { |
|
|
|
void PrimitiveC::AddInputQuantParam(const std::vector<schema::QuantParamT> &quant_param) { |
|
|
|
this->input_quant_param_.emplace_back(quant_param); |
|
|
|
} |
|
|
|
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::input_quant_params() const { return input_quant_param_; } |
|
|
|
|
|
|
|
void PrimitiveC::AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) { |
|
|
|
void PrimitiveC::AddOutputQuantParam(const std::vector<schema::QuantParamT> &quant_param) { |
|
|
|
this->output_quant_param_.emplace_back(quant_param); |
|
|
|
} |
|
|
|
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::output_quant_params() const { return output_quant_param_; } |
|
|
|
@@ -415,7 +415,7 @@ std::shared_ptr<PrimitiveC> GetReturnPrim() { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return_primitiveT->value.type = schema::PrimitiveType_Return; |
|
|
|
return_primitiveT->value.value = new schema::ReturnT; |
|
|
|
return_primitiveT->value.value = new (std::nothrow) schema::ReturnT; |
|
|
|
if (return_primitiveT->value.value == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new ReturnT failed"; |
|
|
|
delete (return_primitiveT); |
|
|
|
@@ -425,13 +425,13 @@ std::shared_ptr<PrimitiveC> GetReturnPrim() { |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<PrimitiveC> GetMakeTuplePrim() { |
|
|
|
auto make_tuple_primitiveT = new schema::PrimitiveT; |
|
|
|
auto make_tuple_primitiveT = new (std::nothrow) schema::PrimitiveT; |
|
|
|
if (make_tuple_primitiveT == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new PrimitiveT failed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple; |
|
|
|
make_tuple_primitiveT->value.value = new schema::MakeTupleT; |
|
|
|
make_tuple_primitiveT->value.value = new (std::nothrow) schema::MakeTupleT; |
|
|
|
if (make_tuple_primitiveT->value.value == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new MakeTupleT failed"; |
|
|
|
delete (make_tuple_primitiveT); |
|
|
|
@@ -441,13 +441,13 @@ std::shared_ptr<PrimitiveC> GetMakeTuplePrim() { |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() { |
|
|
|
auto tuple_get_item_primitiveT = new schema::PrimitiveT(); |
|
|
|
auto tuple_get_item_primitiveT = new (std::nothrow) schema::PrimitiveT(); |
|
|
|
if (tuple_get_item_primitiveT == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new PrimitiveT failed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem; |
|
|
|
tuple_get_item_primitiveT->value.value = new schema::TupleGetItemT; |
|
|
|
tuple_get_item_primitiveT->value.value = new (std::nothrow) schema::TupleGetItemT; |
|
|
|
if (tuple_get_item_primitiveT->value.value == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new TupleGetItemT failed"; |
|
|
|
delete (tuple_get_item_primitiveT); |
|
|
|
@@ -642,316 +642,316 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { |
|
|
|
auto op_type = primitive->value.type; |
|
|
|
switch (op_type) { |
|
|
|
case schema::PrimitiveType_SoftMax: |
|
|
|
return new SoftMax(primitive); |
|
|
|
return new (std::nothrow) SoftMax(primitive); |
|
|
|
case schema::PrimitiveType_Activation: |
|
|
|
return new Activation(primitive); |
|
|
|
return new (std::nothrow) Activation(primitive); |
|
|
|
case schema::PrimitiveType_Conv2D: |
|
|
|
return new Conv2D(primitive); |
|
|
|
return new (std::nothrow) Conv2D(primitive); |
|
|
|
case schema::PrimitiveType_DeConv2D: |
|
|
|
return new DeConv2D(primitive); |
|
|
|
return new (std::nothrow) DeConv2D(primitive); |
|
|
|
case schema::PrimitiveType_Reduce: |
|
|
|
return new Reduce(primitive); |
|
|
|
return new (std::nothrow) Reduce(primitive); |
|
|
|
case schema::PrimitiveType_Pooling: |
|
|
|
return new Pooling(primitive); |
|
|
|
return new (std::nothrow) Pooling(primitive); |
|
|
|
case schema::PrimitiveType_ROIPooling: |
|
|
|
return new ROIPooling(primitive); |
|
|
|
return new (std::nothrow) ROIPooling(primitive); |
|
|
|
case schema::PrimitiveType_DepthwiseConv2D: |
|
|
|
return new DepthwiseConv2D(primitive); |
|
|
|
return new (std::nothrow) DepthwiseConv2D(primitive); |
|
|
|
case schema::PrimitiveType_FusedBatchNorm: |
|
|
|
return new FusedBatchNorm(primitive); |
|
|
|
return new (std::nothrow) FusedBatchNorm(primitive); |
|
|
|
case schema::PrimitiveType_BatchNorm: |
|
|
|
return new BatchNorm(primitive); |
|
|
|
return new (std::nothrow) BatchNorm(primitive); |
|
|
|
case schema::PrimitiveType_FullConnection: |
|
|
|
return new FullConnection(primitive); |
|
|
|
return new (std::nothrow) FullConnection(primitive); |
|
|
|
case schema::PrimitiveType_Power: |
|
|
|
return new Power(primitive); |
|
|
|
return new (std::nothrow) Power(primitive); |
|
|
|
case schema::PrimitiveType_Pad: |
|
|
|
return new Pad(primitive); |
|
|
|
return new (std::nothrow) Pad(primitive); |
|
|
|
case schema::PrimitiveType_Range: |
|
|
|
return new Range(primitive); |
|
|
|
return new (std::nothrow) Range(primitive); |
|
|
|
case schema::PrimitiveType_Mul: |
|
|
|
return new Mul(primitive); |
|
|
|
return new (std::nothrow) Mul(primitive); |
|
|
|
case schema::PrimitiveType_Add: |
|
|
|
return new Add(primitive); |
|
|
|
return new (std::nothrow) Add(primitive); |
|
|
|
case schema::PrimitiveType_Sub: |
|
|
|
return new Sub(primitive); |
|
|
|
return new (std::nothrow) Sub(primitive); |
|
|
|
case schema::PrimitiveType_Div: |
|
|
|
return new Div(primitive); |
|
|
|
return new (std::nothrow) Div(primitive); |
|
|
|
case schema::PrimitiveType_BiasAdd: |
|
|
|
return new BiasAdd(primitive); |
|
|
|
return new (std::nothrow) BiasAdd(primitive); |
|
|
|
case schema::PrimitiveType_ExpandDims: |
|
|
|
return new ExpandDims(primitive); |
|
|
|
return new (std::nothrow) ExpandDims(primitive); |
|
|
|
case schema::PrimitiveType_ArgMax: |
|
|
|
return new ArgMax(primitive); |
|
|
|
return new (std::nothrow) ArgMax(primitive); |
|
|
|
case schema::PrimitiveType_ArgMin: |
|
|
|
return new ArgMin(primitive); |
|
|
|
return new (std::nothrow) ArgMin(primitive); |
|
|
|
case schema::PrimitiveType_Cast: |
|
|
|
return new Cast(primitive); |
|
|
|
return new (std::nothrow) Cast(primitive); |
|
|
|
case schema::PrimitiveType_Reshape: |
|
|
|
return new Reshape(primitive); |
|
|
|
return new (std::nothrow) Reshape(primitive); |
|
|
|
case schema::PrimitiveType_Scale: |
|
|
|
return new Scale(primitive); |
|
|
|
return new (std::nothrow) Scale(primitive); |
|
|
|
case schema::PrimitiveType_Eltwise: |
|
|
|
return new Eltwise(primitive); |
|
|
|
return new (std::nothrow) Eltwise(primitive); |
|
|
|
case schema::PrimitiveType_Ceil: |
|
|
|
return new Ceil(primitive); |
|
|
|
return new (std::nothrow) Ceil(primitive); |
|
|
|
case schema::PrimitiveType_Concat: |
|
|
|
return new Concat(primitive); |
|
|
|
return new (std::nothrow) Concat(primitive); |
|
|
|
case schema::PrimitiveType_Fill: |
|
|
|
return new Fill(primitive); |
|
|
|
return new (std::nothrow) Fill(primitive); |
|
|
|
case schema::PrimitiveType_Nhwc2Nchw: |
|
|
|
return new Nhwc2Nchw(primitive); |
|
|
|
return new (std::nothrow) Nhwc2Nchw(primitive); |
|
|
|
case schema::PrimitiveType_Nchw2Nhwc: |
|
|
|
return new Nchw2Nhwc(primitive); |
|
|
|
return new (std::nothrow) Nchw2Nhwc(primitive); |
|
|
|
case schema::PrimitiveType_Transpose: |
|
|
|
return new Transpose(primitive); |
|
|
|
return new (std::nothrow) Transpose(primitive); |
|
|
|
case schema::PrimitiveType_Slice: |
|
|
|
return new Slice(primitive); |
|
|
|
return new (std::nothrow) Slice(primitive); |
|
|
|
case schema::PrimitiveType_Squeeze: |
|
|
|
return new Squeeze(primitive); |
|
|
|
return new (std::nothrow) Squeeze(primitive); |
|
|
|
case schema::PrimitiveType_Flatten: |
|
|
|
return new Flatten(primitive); |
|
|
|
return new (std::nothrow) Flatten(primitive); |
|
|
|
case schema::PrimitiveType_Mean: |
|
|
|
return new Mean(primitive); |
|
|
|
return new (std::nothrow) Mean(primitive); |
|
|
|
case schema::PrimitiveType_Stack: |
|
|
|
return new Stack(primitive); |
|
|
|
return new (std::nothrow) Stack(primitive); |
|
|
|
case schema::PrimitiveType_Crop: |
|
|
|
return new Crop(primitive); |
|
|
|
return new (std::nothrow) Crop(primitive); |
|
|
|
case schema::PrimitiveType_SquaredDifference: |
|
|
|
return new SquaredDifference(primitive); |
|
|
|
return new (std::nothrow) SquaredDifference(primitive); |
|
|
|
case schema::PrimitiveType_AddN: |
|
|
|
return new AddN(primitive); |
|
|
|
return new (std::nothrow) AddN(primitive); |
|
|
|
case schema::PrimitiveType_Abs: |
|
|
|
return new Abs(primitive); |
|
|
|
return new (std::nothrow) Abs(primitive); |
|
|
|
case schema::PrimitiveType_Sin: |
|
|
|
return new Sin(primitive); |
|
|
|
return new (std::nothrow) Sin(primitive); |
|
|
|
case schema::PrimitiveType_Cos: |
|
|
|
return new Cos(primitive); |
|
|
|
return new (std::nothrow) Cos(primitive); |
|
|
|
case schema::PrimitiveType_Log: |
|
|
|
return new Log(primitive); |
|
|
|
return new (std::nothrow) Log(primitive); |
|
|
|
case schema::PrimitiveType_Sqrt: |
|
|
|
return new Sqrt(primitive); |
|
|
|
return new (std::nothrow) Sqrt(primitive); |
|
|
|
case schema::PrimitiveType_Rsqrt: |
|
|
|
return new Rsqrt(primitive); |
|
|
|
return new (std::nothrow) Rsqrt(primitive); |
|
|
|
case schema::PrimitiveType_Square: |
|
|
|
return new Square(primitive); |
|
|
|
return new (std::nothrow) Square(primitive); |
|
|
|
case schema::PrimitiveType_Exp: |
|
|
|
return new Exp(primitive); |
|
|
|
return new (std::nothrow) Exp(primitive); |
|
|
|
case schema::PrimitiveType_Gather: |
|
|
|
return new Gather(primitive); |
|
|
|
return new (std::nothrow) Gather(primitive); |
|
|
|
case schema::PrimitiveType_GatherNd: |
|
|
|
return new GatherNd(primitive); |
|
|
|
return new (std::nothrow) GatherNd(primitive); |
|
|
|
case schema::PrimitiveType_LocalResponseNormalization: |
|
|
|
return new LocalResponseNormalization(primitive); |
|
|
|
return new (std::nothrow) LocalResponseNormalization(primitive); |
|
|
|
case schema::PrimitiveType_Maximum: |
|
|
|
return new Maximum(primitive); |
|
|
|
return new (std::nothrow) Maximum(primitive); |
|
|
|
case schema::PrimitiveType_Minimum: |
|
|
|
return new Minimum(primitive); |
|
|
|
return new (std::nothrow) Minimum(primitive); |
|
|
|
case schema::PrimitiveType_StridedSlice: |
|
|
|
return new StridedSlice(primitive); |
|
|
|
return new (std::nothrow) StridedSlice(primitive); |
|
|
|
case schema::PrimitiveType_LeakyReLU: |
|
|
|
return new (std::nothrow) LeakyReLU(primitive); |
|
|
|
case schema::PrimitiveType_PReLU: |
|
|
|
return new (std::nothrow) PReLU(primitive); |
|
|
|
case schema::PrimitiveType_Round: |
|
|
|
return new Round(primitive); |
|
|
|
return new (std::nothrow) Round(primitive); |
|
|
|
case schema::PrimitiveType_Reverse: |
|
|
|
return new Reverse(primitive); |
|
|
|
return new (std::nothrow) Reverse(primitive); |
|
|
|
case schema::PrimitiveType_ReverseSequence: |
|
|
|
return new ReverseSequence(primitive); |
|
|
|
return new (std::nothrow) ReverseSequence(primitive); |
|
|
|
case schema::PrimitiveType_LogicalAnd: |
|
|
|
return new LogicalAnd(primitive); |
|
|
|
return new (std::nothrow) LogicalAnd(primitive); |
|
|
|
case schema::PrimitiveType_LogicalOr: |
|
|
|
return new LogicalOr(primitive); |
|
|
|
return new (std::nothrow) LogicalOr(primitive); |
|
|
|
case schema::PrimitiveType_LogicalNot: |
|
|
|
return new LogicalNot(primitive); |
|
|
|
return new (std::nothrow) LogicalNot(primitive); |
|
|
|
case schema::PrimitiveType_FloorDiv: |
|
|
|
return new FloorDiv(primitive); |
|
|
|
return new (std::nothrow) FloorDiv(primitive); |
|
|
|
case schema::PrimitiveType_FloorMod: |
|
|
|
return new FloorMod(primitive); |
|
|
|
return new (std::nothrow) FloorMod(primitive); |
|
|
|
case schema::PrimitiveType_Equal: |
|
|
|
return new Equal(primitive); |
|
|
|
return new (std::nothrow) Equal(primitive); |
|
|
|
case schema::PrimitiveType_NotEqual: |
|
|
|
return new NotEqual(primitive); |
|
|
|
return new (std::nothrow) NotEqual(primitive); |
|
|
|
case schema::PrimitiveType_Less: |
|
|
|
return new Less(primitive); |
|
|
|
return new (std::nothrow) Less(primitive); |
|
|
|
case schema::PrimitiveType_LessEqual: |
|
|
|
return new LessEqual(primitive); |
|
|
|
return new (std::nothrow) LessEqual(primitive); |
|
|
|
case schema::PrimitiveType_Greater: |
|
|
|
return new Greater(primitive); |
|
|
|
return new (std::nothrow) Greater(primitive); |
|
|
|
case schema::PrimitiveType_GreaterEqual: |
|
|
|
return new GreaterEqual(primitive); |
|
|
|
return new (std::nothrow) GreaterEqual(primitive); |
|
|
|
case schema::PrimitiveType_Floor: |
|
|
|
return new Floor(primitive); |
|
|
|
return new (std::nothrow) Floor(primitive); |
|
|
|
case schema::PrimitiveType_Split: |
|
|
|
return new Split(primitive); |
|
|
|
return new (std::nothrow) Split(primitive); |
|
|
|
case schema::PrimitiveType_OneHot: |
|
|
|
return new OneHot(primitive); |
|
|
|
return new (std::nothrow) OneHot(primitive); |
|
|
|
case schema::PrimitiveType_PriorBox: |
|
|
|
return new PriorBox(primitive); |
|
|
|
return new (std::nothrow) PriorBox(primitive); |
|
|
|
case schema::PrimitiveType_SpaceToDepth: |
|
|
|
return new SpaceToDepth(primitive); |
|
|
|
return new (std::nothrow) SpaceToDepth(primitive); |
|
|
|
case schema::PrimitiveType_Tile: |
|
|
|
return new Tile(primitive); |
|
|
|
return new (std::nothrow) Tile(primitive); |
|
|
|
case schema::PrimitiveType_Resize: |
|
|
|
return new Resize(primitive); |
|
|
|
return new (std::nothrow) Resize(primitive); |
|
|
|
case schema::PrimitiveType_Unstack: |
|
|
|
return new Unstack(primitive); |
|
|
|
return new (std::nothrow) Unstack(primitive); |
|
|
|
case schema::PrimitiveType_Unique: |
|
|
|
return new Unique(primitive); |
|
|
|
return new (std::nothrow) Unique(primitive); |
|
|
|
case schema::PrimitiveType_TopK: |
|
|
|
return new TopK(primitive); |
|
|
|
return new (std::nothrow) TopK(primitive); |
|
|
|
case schema::PrimitiveType_MatMul: |
|
|
|
return new MatMul(primitive); |
|
|
|
return new (std::nothrow) MatMul(primitive); |
|
|
|
case schema::PrimitiveType_QuantDTypeCast: |
|
|
|
return new QuantDTypeCast(primitive); |
|
|
|
return new (std::nothrow) QuantDTypeCast(primitive); |
|
|
|
case schema::PrimitiveType_EmbeddingLookup: |
|
|
|
return new EmbeddingLookup(primitive); |
|
|
|
return new (std::nothrow) EmbeddingLookup(primitive); |
|
|
|
case schema::PrimitiveType_Elu: |
|
|
|
return new Elu(primitive); |
|
|
|
return new (std::nothrow) Elu(primitive); |
|
|
|
case schema::PrimitiveType_DeDepthwiseConv2D: |
|
|
|
return new DeDepthwiseConv2D(primitive); |
|
|
|
return new (std::nothrow) DeDepthwiseConv2D(primitive); |
|
|
|
case schema::PrimitiveType_Shape: |
|
|
|
return new Shape(primitive); |
|
|
|
return new (std::nothrow) Shape(primitive); |
|
|
|
case schema::PrimitiveType_Unsqueeze: |
|
|
|
return new Unsqueeze(primitive); |
|
|
|
return new (std::nothrow) Unsqueeze(primitive); |
|
|
|
case schema::PrimitiveType_BatchToSpace: |
|
|
|
case schema::PrimitiveType_BatchToSpaceND: |
|
|
|
return new BatchToSpace(primitive); |
|
|
|
return new (std::nothrow) BatchToSpace(primitive); |
|
|
|
case schema::PrimitiveType_SpaceToBatch: |
|
|
|
return new SpaceToBatch(primitive); |
|
|
|
return new (std::nothrow) SpaceToBatch(primitive); |
|
|
|
case schema::PrimitiveType_SpaceToBatchND: |
|
|
|
return new SpaceToBatchND(primitive); |
|
|
|
return new (std::nothrow) SpaceToBatchND(primitive); |
|
|
|
case schema::PrimitiveType_BroadcastTo: |
|
|
|
return new BroadcastTo(primitive); |
|
|
|
return new (std::nothrow) BroadcastTo(primitive); |
|
|
|
case schema::PrimitiveType_DepthToSpace: |
|
|
|
return new DepthToSpace(primitive); |
|
|
|
return new (std::nothrow) DepthToSpace(primitive); |
|
|
|
case schema::PrimitiveType_Lstm: |
|
|
|
return new Lstm(primitive); |
|
|
|
return new (std::nothrow) Lstm(primitive); |
|
|
|
case schema::PrimitiveType_ZerosLike: |
|
|
|
return new ZerosLike(primitive); |
|
|
|
return new (std::nothrow) ZerosLike(primitive); |
|
|
|
case schema::PrimitiveType_MakeTuple: |
|
|
|
return new MakeTuple(primitive); |
|
|
|
return new (std::nothrow) MakeTuple(primitive); |
|
|
|
case schema::PrimitiveType_Where: |
|
|
|
return new Where(primitive); |
|
|
|
return new (std::nothrow) Where(primitive); |
|
|
|
case schema::PrimitiveType_ScatterND: |
|
|
|
return new ScatterND(primitive); |
|
|
|
return new (std::nothrow) ScatterND(primitive); |
|
|
|
case schema::PrimitiveType_ConstantOfShape: |
|
|
|
return new ConstantOfShape(primitive); |
|
|
|
return new (std::nothrow) ConstantOfShape(primitive); |
|
|
|
case schema::PrimitiveType_L2Norm: |
|
|
|
return new L2Norm(primitive); |
|
|
|
return new (std::nothrow) L2Norm(primitive); |
|
|
|
case schema::PrimitiveType_SparseToDense: |
|
|
|
return new SparseToDense(primitive); |
|
|
|
return new (std::nothrow) SparseToDense(primitive); |
|
|
|
case schema::PrimitiveType_DetectionPostProcess: |
|
|
|
return new DetectionPostProcess(primitive); |
|
|
|
return new (std::nothrow) DetectionPostProcess(primitive); |
|
|
|
case schema::PrimitiveType_Dropout: |
|
|
|
return new Dropout(primitive); |
|
|
|
return new (std::nothrow) Dropout(primitive); |
|
|
|
case schema::PrimitiveType_Neg: |
|
|
|
return new Neg(primitive); |
|
|
|
return new (std::nothrow) Neg(primitive); |
|
|
|
case schema::PrimitiveType_RealDiv: |
|
|
|
return new RealDiv(primitive); |
|
|
|
return new (std::nothrow) RealDiv(primitive); |
|
|
|
case schema::PrimitiveType_LshProjection: |
|
|
|
return new LshProjection(primitive); |
|
|
|
return new (std::nothrow) LshProjection(primitive); |
|
|
|
case schema::PrimitiveType_HashtableLookup: |
|
|
|
return new HashtableLookup(primitive); |
|
|
|
return new (std::nothrow) HashtableLookup(primitive); |
|
|
|
case schema::PrimitiveType_SkipGram: |
|
|
|
return new SkipGram(primitive); |
|
|
|
return new (std::nothrow) SkipGram(primitive); |
|
|
|
case schema::PrimitiveType_Clip: |
|
|
|
return new Clip(primitive); |
|
|
|
return new (std::nothrow) Clip(primitive); |
|
|
|
case schema::PrimitiveType_CustomPredict: |
|
|
|
return new CustomPredict(primitive); |
|
|
|
return new (std::nothrow) CustomPredict(primitive); |
|
|
|
case schema::PrimitiveType_CustomNormalize: |
|
|
|
return new CustomNormalize(primitive); |
|
|
|
return new (std::nothrow) CustomNormalize(primitive); |
|
|
|
case schema::PrimitiveType_CustomExtractFeatures: |
|
|
|
return new CustomExtractFeatures(primitive); |
|
|
|
return new (std::nothrow) CustomExtractFeatures(primitive); |
|
|
|
case schema::PrimitiveType_Upsample: |
|
|
|
return new Upsample(primitive); |
|
|
|
return new (std::nothrow) Upsample(primitive); |
|
|
|
case schema::PrimitiveType_LayerNorm: |
|
|
|
return new LayerNorm(primitive); |
|
|
|
return new (std::nothrow) LayerNorm(primitive); |
|
|
|
case schema::PrimitiveType_NonMaxSuppression: |
|
|
|
return new NonMaxSuppression(primitive); |
|
|
|
return new (std::nothrow) NonMaxSuppression(primitive); |
|
|
|
case schema::PrimitiveType_Identity: |
|
|
|
return new Identity(primitive); |
|
|
|
return new (std::nothrow) Identity(primitive); |
|
|
|
case schema::PrimitiveType_Rfft: |
|
|
|
return new Rfft(primitive); |
|
|
|
return new (std::nothrow) Rfft(primitive); |
|
|
|
case schema::PrimitiveType_FftReal: |
|
|
|
return new FftReal(primitive); |
|
|
|
return new (std::nothrow) FftReal(primitive); |
|
|
|
case schema::PrimitiveType_FftImag: |
|
|
|
return new FftImag(primitive); |
|
|
|
return new (std::nothrow) FftImag(primitive); |
|
|
|
case schema::PrimitiveType_AudioSpectrogram: |
|
|
|
return new AudioSpectrogram(primitive); |
|
|
|
return new (std::nothrow) AudioSpectrogram(primitive); |
|
|
|
case schema::PrimitiveType_Mfcc: |
|
|
|
return new Mfcc(primitive); |
|
|
|
return new (std::nothrow) Mfcc(primitive); |
|
|
|
case schema::PrimitiveType_InstanceNorm: |
|
|
|
return new InstanceNorm(primitive); |
|
|
|
return new (std::nothrow) InstanceNorm(primitive); |
|
|
|
case schema::PrimitiveType_While: |
|
|
|
return new While(primitive); |
|
|
|
return new (std::nothrow) While(primitive); |
|
|
|
case schema::PrimitiveType_OnnxInt8Quantize: |
|
|
|
return new Quant(primitive); |
|
|
|
return new (std::nothrow) Quant(primitive); |
|
|
|
case schema::PrimitiveType_OnnxInt8Dequantize: |
|
|
|
return new Dequant(primitive); |
|
|
|
return new (std::nothrow) Dequant(primitive); |
|
|
|
|
|
|
|
#ifdef SUPPORT_TRAIN |
|
|
|
case schema::PrimitiveType_ActivationGrad: |
|
|
|
return new ActivationGrad(primitive); |
|
|
|
return new (std::nothrow) ActivationGrad(primitive); |
|
|
|
case schema::PrimitiveType_PoolingGrad: |
|
|
|
return new PoolingGrad(primitive); |
|
|
|
return new (std::nothrow) PoolingGrad(primitive); |
|
|
|
case schema::PrimitiveType_Conv2DGradFilter: |
|
|
|
return new Conv2DGradFilter(primitive); |
|
|
|
return new (std::nothrow) Conv2DGradFilter(primitive); |
|
|
|
case schema::PrimitiveType_Conv2DGradInput: |
|
|
|
return new Conv2DGradInput(primitive); |
|
|
|
return new (std::nothrow) Conv2DGradInput(primitive); |
|
|
|
case schema::PrimitiveType_GroupConv2DGradInput: |
|
|
|
return new GroupConv2DGradInput(primitive); |
|
|
|
return new (std::nothrow) GroupConv2DGradInput(primitive); |
|
|
|
case schema::PrimitiveType_BiasGrad: |
|
|
|
return new BiasGrad(primitive); |
|
|
|
return new (std::nothrow) BiasGrad(primitive); |
|
|
|
case schema::PrimitiveType_ApplyMomentum: |
|
|
|
return new ApplyMomentum(primitive); |
|
|
|
return new (std::nothrow) ApplyMomentum(primitive); |
|
|
|
case schema::PrimitiveType_BNGrad: |
|
|
|
return new BNGrad(primitive); |
|
|
|
return new (std::nothrow) BNGrad(primitive); |
|
|
|
case schema::PrimitiveType_AddGrad: |
|
|
|
return new ArithmeticGrad(primitive); |
|
|
|
return new (std::nothrow) ArithmeticGrad(primitive); |
|
|
|
case schema::PrimitiveType_SubGrad: |
|
|
|
return new ArithmeticGrad(primitive); |
|
|
|
return new (std::nothrow) ArithmeticGrad(primitive); |
|
|
|
case schema::PrimitiveType_MulGrad: |
|
|
|
return new ArithmeticGrad(primitive); |
|
|
|
return new (std::nothrow) ArithmeticGrad(primitive); |
|
|
|
case schema::PrimitiveType_DivGrad: |
|
|
|
return new ArithmeticGrad(primitive); |
|
|
|
return new (std::nothrow) ArithmeticGrad(primitive); |
|
|
|
case schema::PrimitiveType_SoftmaxCrossEntropy: |
|
|
|
return new SoftmaxCrossEntropy(primitive); |
|
|
|
return new (std::nothrow) SoftmaxCrossEntropy(primitive); |
|
|
|
case schema::PrimitiveType_PowerGrad: |
|
|
|
return new PowerGrad(primitive); |
|
|
|
return new (std::nothrow) PowerGrad(primitive); |
|
|
|
case schema::PrimitiveType_Depend: |
|
|
|
return new Depend(primitive); |
|
|
|
return new (std::nothrow) Depend(primitive); |
|
|
|
case schema::PrimitiveType_ControlDepend: |
|
|
|
return new ControlDepend(primitive); |
|
|
|
return new (std::nothrow) ControlDepend(primitive); |
|
|
|
case schema::PrimitiveType_FlattenGrad: |
|
|
|
return new FlattenGrad(primitive); |
|
|
|
return new (std::nothrow) FlattenGrad(primitive); |
|
|
|
case schema::PrimitiveType_NegGrad: |
|
|
|
return new NegGrad(primitive); |
|
|
|
return new (std::nothrow) NegGrad(primitive); |
|
|
|
case schema::PrimitiveType_LogGrad: |
|
|
|
return new LogGrad(primitive); |
|
|
|
return new (std::nothrow) LogGrad(primitive); |
|
|
|
case schema::PrimitiveType_Sgd: |
|
|
|
return new Sgd(primitive); |
|
|
|
return new (std::nothrow) Sgd(primitive); |
|
|
|
case schema::PrimitiveType_Adam: |
|
|
|
return new Adam(primitive); |
|
|
|
return new (std::nothrow) Adam(primitive); |
|
|
|
case schema::PrimitiveType_Assign: |
|
|
|
return new Assign(primitive); |
|
|
|
return new (std::nothrow) Assign(primitive); |
|
|
|
case schema::PrimitiveType_AssignAdd: |
|
|
|
return new AssignAdd(primitive); |
|
|
|
return new (std::nothrow) AssignAdd(primitive); |
|
|
|
case schema::PrimitiveType_OnesLike: |
|
|
|
return new OnesLike(primitive); |
|
|
|
return new (std::nothrow) OnesLike(primitive); |
|
|
|
case schema::PrimitiveType_UnsortedSegmentSum: |
|
|
|
return new UnsortedSegmentSum(primitive); |
|
|
|
return new (std::nothrow) UnsortedSegmentSum(primitive); |
|
|
|
case schema::PrimitiveType_BinaryCrossEntropyGrad: |
|
|
|
return new BinaryCrossEntropyGrad(primitive); |
|
|
|
return new (std::nothrow) BinaryCrossEntropyGrad(primitive); |
|
|
|
case schema::PrimitiveType_BinaryCrossEntropy: |
|
|
|
return new BinaryCrossEntropy(primitive); |
|
|
|
return new (std::nothrow) BinaryCrossEntropy(primitive); |
|
|
|
case schema::PrimitiveType_DropoutGrad: |
|
|
|
return new DropoutGrad(primitive); |
|
|
|
return new (std::nothrow) DropoutGrad(primitive); |
|
|
|
case schema::PrimitiveType_MaximumGrad: |
|
|
|
return new MaximumGrad(primitive); |
|
|
|
return new (std::nothrow) MaximumGrad(primitive); |
|
|
|
case schema::PrimitiveType_MinimumGrad: |
|
|
|
return new MinimumGrad(primitive); |
|
|
|
return new (std::nothrow) MinimumGrad(primitive); |
|
|
|
#endif |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); |
|
|
|
|