|
|
@@ -72,76 +72,6 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
attr->shrinkAxisMask = attr_value.i(); |
|
|
attr->shrinkAxisMask = attr_value.i(); |
|
|
|
|
|
|
|
|
// begin |
|
|
|
|
|
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1)); |
|
|
|
|
|
if (begin_node == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "Find StridedSlice input begin failed"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) { |
|
|
|
|
|
MS_LOG(ERROR) << "The value attr should be specified"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
auto tensor_proto = attr_value.tensor(); |
|
|
|
|
|
if (tensor_proto.int_val_size() > 0) { |
|
|
|
|
|
for (int i = 0; i < tensor_proto.int_val_size(); ++i) { |
|
|
|
|
|
attr->begin.push_back(tensor_proto.int_val(i)); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); |
|
|
|
|
|
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()); |
|
|
|
|
|
for (size_t i = 0; i < data_num; ++i) { |
|
|
|
|
|
attr->begin.push_back(data[i]); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// end |
|
|
|
|
|
auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2)); |
|
|
|
|
|
if (end_node == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "Find StridedSlice input end failed"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) { |
|
|
|
|
|
MS_LOG(ERROR) << "The value attr should be specified"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
tensor_proto = attr_value.tensor(); |
|
|
|
|
|
if (tensor_proto.int_val_size() > 0) { |
|
|
|
|
|
for (int i = 0; i < tensor_proto.int_val_size(); ++i) { |
|
|
|
|
|
attr->end.push_back(tensor_proto.int_val(i)); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); |
|
|
|
|
|
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()); |
|
|
|
|
|
for (size_t i = 0; i < data_num; ++i) { |
|
|
|
|
|
attr->end.push_back(data[i]); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// strides |
|
|
|
|
|
auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3)); |
|
|
|
|
|
if (stride_node == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "Find StridedSlice input strides failed"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) { |
|
|
|
|
|
MS_LOG(ERROR) << "The value attr should be specified"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
tensor_proto = attr_value.tensor(); |
|
|
|
|
|
if (tensor_proto.int_val_size() > 0) { |
|
|
|
|
|
for (int i = 0; i < tensor_proto.int_val_size(); ++i) { |
|
|
|
|
|
attr->stride.push_back(tensor_proto.int_val(i)); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); |
|
|
|
|
|
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()); |
|
|
|
|
|
for (size_t i = 0; i < data_num; ++i) { |
|
|
|
|
|
attr->stride.push_back(data[i]); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
primitive->value.type = schema::PrimitiveType_StridedSlice; |
|
|
primitive->value.type = schema::PrimitiveType_StridedSlice; |
|
|
primitive->value.value = attr.release(); |
|
|
primitive->value.value = attr.release(); |
|
|
*primitiveC = PrimitiveC::Create(primitive.release()); |
|
|
*primitiveC = PrimitiveC::Create(primitive.release()); |
|
|
@@ -151,7 +81,14 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
*output_size = 1; |
|
|
*output_size = 1; |
|
|
auto status = AddOpInput(tf_op, 0, inputs); |
|
|
|
|
|
|
|
|
STATUS status = RET_OK; |
|
|
|
|
|
for (int i = 0; i < tf_op.input_size(); i++) { |
|
|
|
|
|
status = AddOpInput(tf_op, i, inputs); |
|
|
|
|
|
if (status != RET_OK) { |
|
|
|
|
|
MS_LOG(ERROR) << "Add Op input failed."; |
|
|
|
|
|
return status; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
return status; |
|
|
return status; |
|
|
} |
|
|
} |
|
|
TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser()); |
|
|
TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser()); |
|
|
|