| @@ -106,6 +106,7 @@ endif() # NOT ENABLE_ACL | |||||
| if (ENABLE_SERVING) | if (ENABLE_SERVING) | ||||
| add_subdirectory(serving) | add_subdirectory(serving) | ||||
| add_subdirectory(serving/example/cpp_client) | |||||
| endif() | endif() | ||||
| if (NOT ENABLE_ACL) | if (NOT ENABLE_ACL) | ||||
| @@ -1,318 +0,0 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
| # source: ms_service.proto | |||||
| from google.protobuf.internal import enum_type_wrapper | |||||
| from google.protobuf import descriptor as _descriptor | |||||
| from google.protobuf import message as _message | |||||
| from google.protobuf import reflection as _reflection | |||||
| from google.protobuf import symbol_database as _symbol_database | |||||
| # @@protoc_insertion_point(imports) | |||||
| _sym_db = _symbol_database.Default() | |||||
| DESCRIPTOR = _descriptor.FileDescriptor( | |||||
| name='ms_service.proto', | |||||
| package='ms_serving', | |||||
| syntax='proto3', | |||||
| serialized_options=None, | |||||
| serialized_pb=b'\n\x10ms_service.proto\x12\nms_serving\"2\n\x0ePredictRequest\x12 \n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"2\n\x0cPredictReply\x12\"\n\x06result\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"\x1b\n\x0bTensorShape\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\"p\n\x06Tensor\x12-\n\x0ctensor_shape\x18\x01 \x01(\x0b\x32\x17.ms_serving.TensorShape\x12)\n\x0btensor_type\x18\x02 \x01(\x0e\x32\x14.ms_serving.DataType\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c*\xc9\x01\n\x08\x44\x61taType\x12\x0e\n\nMS_UNKNOWN\x10\x00\x12\x0b\n\x07MS_BOOL\x10\x01\x12\x0b\n\x07MS_INT8\x10\x02\x12\x0c\n\x08MS_UINT8\x10\x03\x12\x0c\n\x08MS_INT16\x10\x04\x12\r\n\tMS_UINT16\x10\x05\x12\x0c\n\x08MS_INT32\x10\x06\x12\r\n\tMS_UINT32\x10\x07\x12\x0c\n\x08MS_INT64\x10\x08\x12\r\n\tMS_UINT64\x10\t\x12\x0e\n\nMS_FLOAT16\x10\n\x12\x0e\n\nMS_FLOAT32\x10\x0b\x12\x0e\n\nMS_FLOAT64\x10\x0c\x32\x8e\x01\n\tMSService\x12\x41\n\x07Predict\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x12>\n\x04Test\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x62\x06proto3' | |||||
| ) | |||||
| _DATATYPE = _descriptor.EnumDescriptor( | |||||
| name='DataType', | |||||
| full_name='ms_serving.DataType', | |||||
| filename=None, | |||||
| file=DESCRIPTOR, | |||||
| values=[ | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_UNKNOWN', index=0, number=0, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_BOOL', index=1, number=1, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_INT8', index=2, number=2, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_UINT8', index=3, number=3, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_INT16', index=4, number=4, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_UINT16', index=5, number=5, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_INT32', index=6, number=6, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_UINT32', index=7, number=7, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_INT64', index=8, number=8, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_UINT64', index=9, number=9, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_FLOAT16', index=10, number=10, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_FLOAT32', index=11, number=11, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| _descriptor.EnumValueDescriptor( | |||||
| name='MS_FLOAT64', index=12, number=12, | |||||
| serialized_options=None, | |||||
| type=None), | |||||
| ], | |||||
| containing_type=None, | |||||
| serialized_options=None, | |||||
| serialized_start=280, | |||||
| serialized_end=481, | |||||
| ) | |||||
| _sym_db.RegisterEnumDescriptor(_DATATYPE) | |||||
| DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) | |||||
| MS_UNKNOWN = 0 | |||||
| MS_BOOL = 1 | |||||
| MS_INT8 = 2 | |||||
| MS_UINT8 = 3 | |||||
| MS_INT16 = 4 | |||||
| MS_UINT16 = 5 | |||||
| MS_INT32 = 6 | |||||
| MS_UINT32 = 7 | |||||
| MS_INT64 = 8 | |||||
| MS_UINT64 = 9 | |||||
| MS_FLOAT16 = 10 | |||||
| MS_FLOAT32 = 11 | |||||
| MS_FLOAT64 = 12 | |||||
| _PREDICTREQUEST = _descriptor.Descriptor( | |||||
| name='PredictRequest', | |||||
| full_name='ms_serving.PredictRequest', | |||||
| filename=None, | |||||
| file=DESCRIPTOR, | |||||
| containing_type=None, | |||||
| fields=[ | |||||
| _descriptor.FieldDescriptor( | |||||
| name='data', full_name='ms_serving.PredictRequest.data', index=0, | |||||
| number=1, type=11, cpp_type=10, label=3, | |||||
| has_default_value=False, default_value=[], | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| ], | |||||
| extensions=[ | |||||
| ], | |||||
| nested_types=[], | |||||
| enum_types=[ | |||||
| ], | |||||
| serialized_options=None, | |||||
| is_extendable=False, | |||||
| syntax='proto3', | |||||
| extension_ranges=[], | |||||
| oneofs=[ | |||||
| ], | |||||
| serialized_start=32, | |||||
| serialized_end=82, | |||||
| ) | |||||
| _PREDICTREPLY = _descriptor.Descriptor( | |||||
| name='PredictReply', | |||||
| full_name='ms_serving.PredictReply', | |||||
| filename=None, | |||||
| file=DESCRIPTOR, | |||||
| containing_type=None, | |||||
| fields=[ | |||||
| _descriptor.FieldDescriptor( | |||||
| name='result', full_name='ms_serving.PredictReply.result', index=0, | |||||
| number=1, type=11, cpp_type=10, label=3, | |||||
| has_default_value=False, default_value=[], | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| ], | |||||
| extensions=[ | |||||
| ], | |||||
| nested_types=[], | |||||
| enum_types=[ | |||||
| ], | |||||
| serialized_options=None, | |||||
| is_extendable=False, | |||||
| syntax='proto3', | |||||
| extension_ranges=[], | |||||
| oneofs=[ | |||||
| ], | |||||
| serialized_start=84, | |||||
| serialized_end=134, | |||||
| ) | |||||
| _TENSORSHAPE = _descriptor.Descriptor( | |||||
| name='TensorShape', | |||||
| full_name='ms_serving.TensorShape', | |||||
| filename=None, | |||||
| file=DESCRIPTOR, | |||||
| containing_type=None, | |||||
| fields=[ | |||||
| _descriptor.FieldDescriptor( | |||||
| name='dims', full_name='ms_serving.TensorShape.dims', index=0, | |||||
| number=1, type=3, cpp_type=2, label=3, | |||||
| has_default_value=False, default_value=[], | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| ], | |||||
| extensions=[ | |||||
| ], | |||||
| nested_types=[], | |||||
| enum_types=[ | |||||
| ], | |||||
| serialized_options=None, | |||||
| is_extendable=False, | |||||
| syntax='proto3', | |||||
| extension_ranges=[], | |||||
| oneofs=[ | |||||
| ], | |||||
| serialized_start=136, | |||||
| serialized_end=163, | |||||
| ) | |||||
| _TENSOR = _descriptor.Descriptor( | |||||
| name='Tensor', | |||||
| full_name='ms_serving.Tensor', | |||||
| filename=None, | |||||
| file=DESCRIPTOR, | |||||
| containing_type=None, | |||||
| fields=[ | |||||
| _descriptor.FieldDescriptor( | |||||
| name='tensor_shape', full_name='ms_serving.Tensor.tensor_shape', index=0, | |||||
| number=1, type=11, cpp_type=10, label=1, | |||||
| has_default_value=False, default_value=None, | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| _descriptor.FieldDescriptor( | |||||
| name='tensor_type', full_name='ms_serving.Tensor.tensor_type', index=1, | |||||
| number=2, type=14, cpp_type=8, label=1, | |||||
| has_default_value=False, default_value=0, | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| _descriptor.FieldDescriptor( | |||||
| name='data', full_name='ms_serving.Tensor.data', index=2, | |||||
| number=3, type=12, cpp_type=9, label=1, | |||||
| has_default_value=False, default_value=b"", | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| ], | |||||
| extensions=[ | |||||
| ], | |||||
| nested_types=[], | |||||
| enum_types=[ | |||||
| ], | |||||
| serialized_options=None, | |||||
| is_extendable=False, | |||||
| syntax='proto3', | |||||
| extension_ranges=[], | |||||
| oneofs=[ | |||||
| ], | |||||
| serialized_start=165, | |||||
| serialized_end=277, | |||||
| ) | |||||
| _PREDICTREQUEST.fields_by_name['data'].message_type = _TENSOR | |||||
| _PREDICTREPLY.fields_by_name['result'].message_type = _TENSOR | |||||
| _TENSOR.fields_by_name['tensor_shape'].message_type = _TENSORSHAPE | |||||
| _TENSOR.fields_by_name['tensor_type'].enum_type = _DATATYPE | |||||
| DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST | |||||
| DESCRIPTOR.message_types_by_name['PredictReply'] = _PREDICTREPLY | |||||
| DESCRIPTOR.message_types_by_name['TensorShape'] = _TENSORSHAPE | |||||
| DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR | |||||
| DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE | |||||
| _sym_db.RegisterFileDescriptor(DESCRIPTOR) | |||||
| PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), { | |||||
| 'DESCRIPTOR' : _PREDICTREQUEST, | |||||
| '__module__' : 'ms_service_pb2' | |||||
| # @@protoc_insertion_point(class_scope:ms_serving.PredictRequest) | |||||
| }) | |||||
| _sym_db.RegisterMessage(PredictRequest) | |||||
| PredictReply = _reflection.GeneratedProtocolMessageType('PredictReply', (_message.Message,), { | |||||
| 'DESCRIPTOR' : _PREDICTREPLY, | |||||
| '__module__' : 'ms_service_pb2' | |||||
| # @@protoc_insertion_point(class_scope:ms_serving.PredictReply) | |||||
| }) | |||||
| _sym_db.RegisterMessage(PredictReply) | |||||
| TensorShape = _reflection.GeneratedProtocolMessageType('TensorShape', (_message.Message,), { | |||||
| 'DESCRIPTOR' : _TENSORSHAPE, | |||||
| '__module__' : 'ms_service_pb2' | |||||
| # @@protoc_insertion_point(class_scope:ms_serving.TensorShape) | |||||
| }) | |||||
| _sym_db.RegisterMessage(TensorShape) | |||||
| Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), { | |||||
| 'DESCRIPTOR' : _TENSOR, | |||||
| '__module__' : 'ms_service_pb2' | |||||
| # @@protoc_insertion_point(class_scope:ms_serving.Tensor) | |||||
| }) | |||||
| _sym_db.RegisterMessage(Tensor) | |||||
| _MSSERVICE = _descriptor.ServiceDescriptor( | |||||
| name='MSService', | |||||
| full_name='ms_serving.MSService', | |||||
| file=DESCRIPTOR, | |||||
| index=0, | |||||
| serialized_options=None, | |||||
| serialized_start=484, | |||||
| serialized_end=626, | |||||
| methods=[ | |||||
| _descriptor.MethodDescriptor( | |||||
| name='Predict', | |||||
| full_name='ms_serving.MSService.Predict', | |||||
| index=0, | |||||
| containing_service=None, | |||||
| input_type=_PREDICTREQUEST, | |||||
| output_type=_PREDICTREPLY, | |||||
| serialized_options=None, | |||||
| ), | |||||
| _descriptor.MethodDescriptor( | |||||
| name='Test', | |||||
| full_name='ms_serving.MSService.Test', | |||||
| index=1, | |||||
| containing_service=None, | |||||
| input_type=_PREDICTREQUEST, | |||||
| output_type=_PREDICTREPLY, | |||||
| serialized_options=None, | |||||
| ), | |||||
| ]) | |||||
| _sym_db.RegisterServiceDescriptor(_MSSERVICE) | |||||
| DESCRIPTOR.services_by_name['MSService'] = _MSSERVICE | |||||
| # @@protoc_insertion_point(module_scope) | |||||
| @@ -1,96 +0,0 @@ | |||||
| # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! | |||||
| import grpc | |||||
| import ms_service_pb2 as ms__service__pb2 | |||||
| class MSServiceStub(object): | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| def __init__(self, channel): | |||||
| """Constructor. | |||||
| Args: | |||||
| channel: A grpc.Channel. | |||||
| """ | |||||
| self.Predict = channel.unary_unary( | |||||
| '/ms_serving.MSService/Predict', | |||||
| request_serializer=ms__service__pb2.PredictRequest.SerializeToString, | |||||
| response_deserializer=ms__service__pb2.PredictReply.FromString, | |||||
| ) | |||||
| self.Test = channel.unary_unary( | |||||
| '/ms_serving.MSService/Test', | |||||
| request_serializer=ms__service__pb2.PredictRequest.SerializeToString, | |||||
| response_deserializer=ms__service__pb2.PredictReply.FromString, | |||||
| ) | |||||
| class MSServiceServicer(object): | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| def Predict(self, request, context): | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||||
| context.set_details('Method not implemented!') | |||||
| raise NotImplementedError('Method not implemented!') | |||||
| def Test(self, request, context): | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||||
| context.set_details('Method not implemented!') | |||||
| raise NotImplementedError('Method not implemented!') | |||||
| def add_MSServiceServicer_to_server(servicer, server): | |||||
| rpc_method_handlers = { | |||||
| 'Predict': grpc.unary_unary_rpc_method_handler( | |||||
| servicer.Predict, | |||||
| request_deserializer=ms__service__pb2.PredictRequest.FromString, | |||||
| response_serializer=ms__service__pb2.PredictReply.SerializeToString, | |||||
| ), | |||||
| 'Test': grpc.unary_unary_rpc_method_handler( | |||||
| servicer.Test, | |||||
| request_deserializer=ms__service__pb2.PredictRequest.FromString, | |||||
| response_serializer=ms__service__pb2.PredictReply.SerializeToString, | |||||
| ), | |||||
| } | |||||
| generic_handler = grpc.method_handlers_generic_handler( | |||||
| 'ms_serving.MSService', rpc_method_handlers) | |||||
| server.add_generic_rpc_handlers((generic_handler,)) | |||||
| # This class is part of an EXPERIMENTAL API. | |||||
| class MSService(object): | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| @staticmethod | |||||
| def Predict(request, | |||||
| target, | |||||
| options=(), | |||||
| channel_credentials=None, | |||||
| call_credentials=None, | |||||
| compression=None, | |||||
| wait_for_ready=None, | |||||
| timeout=None, | |||||
| metadata=None): | |||||
| return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Predict', | |||||
| ms__service__pb2.PredictRequest.SerializeToString, | |||||
| ms__service__pb2.PredictReply.FromString, | |||||
| options, channel_credentials, | |||||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| @staticmethod | |||||
| def Test(request, | |||||
| target, | |||||
| options=(), | |||||
| channel_credentials=None, | |||||
| call_credentials=None, | |||||
| compression=None, | |||||
| wait_for_ready=None, | |||||
| timeout=None, | |||||
| metadata=None): | |||||
| return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Test', | |||||
| ms__service__pb2.PredictRequest.SerializeToString, | |||||
| ms__service__pb2.PredictReply.FromString, | |||||
| options, channel_credentials, | |||||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| @@ -1,67 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <grpcpp/grpcpp.h> | |||||
| #include <grpcpp/health_check_service_interface.h> | |||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||||
| #include <iostream> | |||||
| #include "./ms_service.grpc.pb.h" | |||||
| using grpc::Server; | |||||
| using grpc::ServerBuilder; | |||||
| using grpc::ServerContext; | |||||
| using grpc::Status; | |||||
| using ms_serving::MSService; | |||||
| using ms_serving::PredictReply; | |||||
| using ms_serving::PredictRequest; | |||||
| // Logic and data behind the server's behavior. | |||||
| class MSServiceImpl final : public MSService::Service { | |||||
| Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override { | |||||
| std::cout << "server eval" << std::endl; | |||||
| return Status::OK; | |||||
| } | |||||
| }; | |||||
| void RunServer() { | |||||
| std::string server_address("0.0.0.0:50051"); | |||||
| MSServiceImpl service; | |||||
| grpc::EnableDefaultHealthCheckService(true); | |||||
| grpc::reflection::InitProtoReflectionServerBuilderPlugin(); | |||||
| auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0); | |||||
| ServerBuilder builder; | |||||
| builder.SetOption(std::move(option)); | |||||
| // Listen on the given address without any authentication mechanism. | |||||
| builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); | |||||
| // Register "service" as the instance through which we'll communicate with | |||||
| // clients. In this case it corresponds to an *synchronous* service. | |||||
| builder.RegisterService(&service); | |||||
| // Finally assemble the server. | |||||
| std::unique_ptr<Server> server(builder.BuildAndStart()); | |||||
| std::cout << "Server listening on " << server_address << std::endl; | |||||
| // Wait for the server to shutdown. Note that some other thread must be | |||||
| // responsible for shutting down the server for this call to ever return. | |||||
| server->Wait(); | |||||
| } | |||||
| int main(int argc, char **argv) { | |||||
| RunServer(); | |||||
| return 0; | |||||
| } | |||||
| @@ -1,6 +1,6 @@ | |||||
| cmake_minimum_required(VERSION 3.5.1) | cmake_minimum_required(VERSION 3.5.1) | ||||
| project(HelloWorld C CXX) | |||||
| project(MSClient C CXX) | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") | ||||
| add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) | add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) | ||||
| @@ -12,17 +12,33 @@ find_package(Threads REQUIRED) | |||||
| # Find Protobuf installation | # Find Protobuf installation | ||||
| # Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. | # Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. | ||||
| set(protobuf_MODULE_COMPATIBLE TRUE) | |||||
| find_package(Protobuf CONFIG REQUIRED) | |||||
| message(STATUS "Using protobuf ${protobuf_VERSION}") | |||||
| option(GRPC_PATH "set grpc path") | |||||
| if(GRPC_PATH) | |||||
| set(CMAKE_PREFIX_PATH ${GRPC_PATH}) | |||||
| set(protobuf_MODULE_COMPATIBLE TRUE) | |||||
| find_package(Protobuf CONFIG REQUIRED) | |||||
| message(STATUS "Using protobuf ${protobuf_VERSION}, CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}") | |||||
| elseif(NOT GRPC_PATH) | |||||
| if (EXISTS ${grpc_ROOT}/lib64) | |||||
| set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc") | |||||
| elseif(EXISTS ${grpc_ROOT}/lib) | |||||
| set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc") | |||||
| endif() | |||||
| add_library(protobuf::libprotobuf ALIAS protobuf::protobuf) | |||||
| add_executable(protobuf::libprotoc ALIAS protobuf::protoc) | |||||
| message(STATUS "serving using grpc_DIR : " ${gRPC_DIR}) | |||||
| elseif(NOT gRPC_DIR AND NOT GRPC_PATH) | |||||
| message("please check gRPC. If the client is compiled separately,you can use the command: cmake -D GRPC_PATH=xxx") | |||||
| message("XXX is the gRPC installation path") | |||||
| endif() | |||||
| set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) | set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) | ||||
| set(_REFLECTION gRPC::grpc++_reflection) | set(_REFLECTION gRPC::grpc++_reflection) | ||||
| if (CMAKE_CROSSCOMPILING) | |||||
| find_program(_PROTOBUF_PROTOC protoc) | |||||
| else () | |||||
| set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>) | |||||
| endif () | |||||
| if(CMAKE_CROSSCOMPILING) | |||||
| find_program(_PROTOBUF_PROTOC protoc) | |||||
| else() | |||||
| set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>) | |||||
| endif() | |||||
| # Find gRPC installation | # Find gRPC installation | ||||
| # Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. | # Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. | ||||
| @@ -30,14 +46,14 @@ find_package(gRPC CONFIG REQUIRED) | |||||
| message(STATUS "Using gRPC ${gRPC_VERSION}") | message(STATUS "Using gRPC ${gRPC_VERSION}") | ||||
| set(_GRPC_GRPCPP gRPC::grpc++) | set(_GRPC_GRPCPP gRPC::grpc++) | ||||
| if (CMAKE_CROSSCOMPILING) | |||||
| find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) | |||||
| else () | |||||
| set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>) | |||||
| endif () | |||||
| if(CMAKE_CROSSCOMPILING) | |||||
| find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) | |||||
| else() | |||||
| set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>) | |||||
| endif() | |||||
| # Proto file | # Proto file | ||||
| get_filename_component(hw_proto "../ms_service.proto" ABSOLUTE) | |||||
| get_filename_component(hw_proto "../../ms_service.proto" ABSOLUTE) | |||||
| get_filename_component(hw_proto_path "${hw_proto}" PATH) | get_filename_component(hw_proto_path "${hw_proto}" PATH) | ||||
| # Generated sources | # Generated sources | ||||
| @@ -59,13 +75,13 @@ add_custom_command( | |||||
| include_directories("${CMAKE_CURRENT_BINARY_DIR}") | include_directories("${CMAKE_CURRENT_BINARY_DIR}") | ||||
| # Targets greeter_[async_](client|server) | # Targets greeter_[async_](client|server) | ||||
| foreach (_target | |||||
| ms_client ms_server) | |||||
| add_executable(${_target} "${_target}.cc" | |||||
| ${hw_proto_srcs} | |||||
| ${hw_grpc_srcs}) | |||||
| target_link_libraries(${_target} | |||||
| ${_REFLECTION} | |||||
| ${_GRPC_GRPCPP} | |||||
| ${_PROTOBUF_LIBPROTOBUF}) | |||||
| endforeach () | |||||
| foreach(_target | |||||
| ms_client) | |||||
| add_executable(${_target} "${_target}.cc" | |||||
| ${hw_proto_srcs} | |||||
| ${hw_grpc_srcs}) | |||||
| target_link_libraries(${_target} | |||||
| ${_REFLECTION} | |||||
| ${_GRPC_GRPCPP} | |||||
| ${_PROTOBUF_LIBPROTOBUF}) | |||||
| endforeach() | |||||
| @@ -211,77 +211,12 @@ PredictRequest ReadBertInput() { | |||||
| return request; | return request; | ||||
| } | } | ||||
| PredictRequest ReadLenetInput() { | |||||
| size_t size; | |||||
| auto buf = ReadFile("lenet_img.bin", &size); | |||||
| if (buf == nullptr) { | |||||
| std::cout << "read file failed" << std::endl; | |||||
| return PredictRequest(); | |||||
| } | |||||
| PredictRequest request; | |||||
| auto cur = buf; | |||||
| if (size > 0) { | |||||
| Tensor data; | |||||
| TensorShape shape; | |||||
| // set type | |||||
| data.set_tensor_type(ms_serving::MS_FLOAT32); | |||||
| // set shape | |||||
| shape.add_dims(size / sizeof(float)); | |||||
| *data.mutable_tensor_shape() = shape; | |||||
| // set data | |||||
| data.set_data(cur, size); | |||||
| *request.add_data() = data; | |||||
| } | |||||
| std::cout << "get input data size " << size << std::endl; | |||||
| return request; | |||||
| } | |||||
| PredictRequest ReadOtherInput(const std::string &data_file) { | |||||
| size_t size; | |||||
| auto buf = ReadFile(data_file.c_str(), &size); | |||||
| if (buf == nullptr) { | |||||
| std::cout << "read file failed" << std::endl; | |||||
| return PredictRequest(); | |||||
| } | |||||
| PredictRequest request; | |||||
| auto cur = buf; | |||||
| if (size > 0) { | |||||
| Tensor data; | |||||
| TensorShape shape; | |||||
| // set type | |||||
| data.set_tensor_type(ms_serving::MS_FLOAT32); | |||||
| // set shape | |||||
| shape.add_dims(size / sizeof(float)); | |||||
| *data.mutable_tensor_shape() = shape; | |||||
| // set data | |||||
| data.set_data(cur, size); | |||||
| *request.add_data() = data; | |||||
| } | |||||
| std::cout << "get input data size " << size << std::endl; | |||||
| return request; | |||||
| } | |||||
| template <class DT> | |||||
| void print_array_item(const DT *data, size_t size) { | |||||
| for (size_t i = 0; i < size && i < 100; i++) { | |||||
| std::cout << data[i] << '\t'; | |||||
| if ((i + 1) % 10 == 0) { | |||||
| std::cout << std::endl; | |||||
| } | |||||
| } | |||||
| std::cout << std::endl; | |||||
| } | |||||
| class MSClient { | class MSClient { | ||||
| public: | public: | ||||
| explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {} | explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {} | ||||
| ~MSClient() = default; | ~MSClient() = default; | ||||
| std::string Predict(const std::string &type, const std::string &data_file) { | |||||
| std::string Predict(const std::string &type) { | |||||
| // Data we are sending to the server. | // Data we are sending to the server. | ||||
| PredictRequest request; | PredictRequest request; | ||||
| if (type == "add") { | if (type == "add") { | ||||
| @@ -299,10 +234,6 @@ class MSClient { | |||||
| *request.add_data() = data; | *request.add_data() = data; | ||||
| } else if (type == "bert") { | } else if (type == "bert") { | ||||
| request = ReadBertInput(); | request = ReadBertInput(); | ||||
| } else if (type == "lenet") { | |||||
| request = ReadLenetInput(); | |||||
| } else if (type == "other") { | |||||
| request = ReadOtherInput(data_file); | |||||
| } else { | } else { | ||||
| std::cout << "type only support bert or add, but input is " << type << std::endl; | std::cout << "type only support bert or add, but input is " << type << std::endl; | ||||
| } | } | ||||
| @@ -325,20 +256,6 @@ class MSClient { | |||||
| // Act upon its status. | // Act upon its status. | ||||
| if (status.ok()) { | if (status.ok()) { | ||||
| for (size_t i = 0; i < reply.result_size(); i++) { | |||||
| auto result = reply.result(i); | |||||
| if (result.tensor_type() == ms_serving::DataType::MS_FLOAT32) { | |||||
| print_array_item(reinterpret_cast<const float *>(result.data().data()), result.data().size() / sizeof(float)); | |||||
| } else if (result.tensor_type() == ms_serving::DataType::MS_INT32) { | |||||
| print_array_item(reinterpret_cast<const int32_t *>(result.data().data()), | |||||
| result.data().size() / sizeof(int32_t)); | |||||
| } else if (result.tensor_type() == ms_serving::DataType::MS_UINT32) { | |||||
| print_array_item(reinterpret_cast<const uint32_t *>(result.data().data()), | |||||
| result.data().size() / sizeof(uint32_t)); | |||||
| } else { | |||||
| std::cout << "output datatype " << result.tensor_type() << std::endl; | |||||
| } | |||||
| } | |||||
| return "RPC OK"; | return "RPC OK"; | ||||
| } else { | } else { | ||||
| std::cout << status.error_code() << ": " << status.error_message() << std::endl; | std::cout << status.error_code() << ": " << status.error_message() << std::endl; | ||||
| @@ -360,8 +277,6 @@ int main(int argc, char **argv) { | |||||
| std::string arg_target_str("--target"); | std::string arg_target_str("--target"); | ||||
| std::string type; | std::string type; | ||||
| std::string arg_type_str("--type"); | std::string arg_type_str("--type"); | ||||
| std::string arg_data_str("--data"); | |||||
| std::string data = "default_data.bin"; | |||||
| if (argc > 2) { | if (argc > 2) { | ||||
| { | { | ||||
| // parse target | // parse target | ||||
| @@ -389,33 +304,19 @@ int main(int argc, char **argv) { | |||||
| if (arg_val2[start_pos] == '=') { | if (arg_val2[start_pos] == '=') { | ||||
| type = arg_val2.substr(start_pos + 1); | type = arg_val2.substr(start_pos + 1); | ||||
| } else { | } else { | ||||
| std::cout << "The only correct argument syntax is --type=" << std::endl; | |||||
| std::cout << "The only correct argument syntax is --target=" << std::endl; | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| } else { | } else { | ||||
| type = "add"; | type = "add"; | ||||
| } | } | ||||
| } | } | ||||
| if (argc > 3) { | |||||
| // parse type | |||||
| std::string arg_val3 = argv[3]; | |||||
| size_t start_pos = arg_val3.find(arg_data_str); | |||||
| if (start_pos != std::string::npos) { | |||||
| start_pos += arg_data_str.size(); | |||||
| if (arg_val3[start_pos] == '=') { | |||||
| data = arg_val3.substr(start_pos + 1); | |||||
| } else { | |||||
| std::cout << "The only correct argument syntax is --data=" << std::endl; | |||||
| return 0; | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| target_str = "localhost:5500"; | target_str = "localhost:5500"; | ||||
| type = "add"; | type = "add"; | ||||
| } | } | ||||
| MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); | MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); | ||||
| std::string reply = client.Predict(type, data); | |||||
| std::string reply = client.Predict(type); | |||||
| std::cout << "client received: " << reply << std::endl; | std::cout << "client received: " << reply << std::endl; | ||||
| return 0; | return 0; | ||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import export | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, x_, y_): | |||||
| return self.add(x_, y_) | |||||
| x = np.ones(4).astype(np.float32) | |||||
| y = np.ones(4).astype(np.float32) | |||||
| def export_net(): | |||||
| add = Net() | |||||
| output = add(Tensor(x), Tensor(y)) | |||||
| export(add, Tensor(x), Tensor(y), file_name='tensor_add.pb', file_format='BINARY') | |||||
| print(x) | |||||
| print(y) | |||||
| print(output.asnumpy()) | |||||
| if __name__ == "__main__": | |||||
| export_net() | |||||
| @@ -19,28 +19,25 @@ import ms_service_pb2_grpc | |||||
| def run(): | def run(): | ||||
| channel = grpc.insecure_channel('localhost:50051') | |||||
| channel = grpc.insecure_channel('localhost:5050') | |||||
| stub = ms_service_pb2_grpc.MSServiceStub(channel) | stub = ms_service_pb2_grpc.MSServiceStub(channel) | ||||
| # request = ms_service_pb2.EvalRequest() | |||||
| # request.name = 'haha' | |||||
| # response = stub.Eval(request) | |||||
| # print("ms client received: " + response.message) | |||||
| request = ms_service_pb2.PredictRequest() | request = ms_service_pb2.PredictRequest() | ||||
| request.data.tensor_shape.dims.extend([32, 1, 32, 32]) | |||||
| request.data.tensor_type = ms_service_pb2.MS_FLOAT32 | |||||
| request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes() | |||||
| request.label.tensor_shape.dims.extend([32]) | |||||
| request.label.tensor_type = ms_service_pb2.MS_INT32 | |||||
| request.label.data = np.ones([32]).astype(np.int32).tobytes() | |||||
| result = stub.Test(request) | |||||
| #result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims) | |||||
| print("ms client test call received: ") | |||||
| #print(result_np) | |||||
| x = request.data.add() | |||||
| x.tensor_shape.dims.extend([4]) | |||||
| x.tensor_type = ms_service_pb2.MS_FLOAT32 | |||||
| x.data = (np.ones([4]).astype(np.float32)).tobytes() | |||||
| y = request.data.add() | |||||
| y.tensor_shape.dims.extend([4]) | |||||
| y.tensor_type = ms_service_pb2.MS_FLOAT32 | |||||
| y.data = (np.ones([4]).astype(np.float32)).tobytes() | |||||
| result = stub.Predict(request) | |||||
| print(result) | |||||
| result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims) | |||||
| print("ms client received: ") | |||||
| print(result_np) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| run() | run() | ||||
| @@ -1,57 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import grpc | |||||
| import numpy as np | |||||
| import ms_service_pb2 | |||||
| import ms_service_pb2_grpc | |||||
| def run(): | |||||
| channel = grpc.insecure_channel('localhost:50051') | |||||
| stub = ms_service_pb2_grpc.MSServiceStub(channel) | |||||
| # request = ms_service_pb2.PredictRequest() | |||||
| # request.name = 'haha' | |||||
| # response = stub.Eval(request) | |||||
| # print("ms client received: " + response.message) | |||||
| request = ms_service_pb2.PredictRequest() | |||||
| request.data.tensor_shape.dims.extend([32, 1, 32, 32]) | |||||
| request.data.tensor_type = ms_service_pb2.MS_FLOAT32 | |||||
| request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes() | |||||
| request.label.tensor_shape.dims.extend([32]) | |||||
| request.label.tensor_type = ms_service_pb2.MS_INT32 | |||||
| request.label.data = np.ones([32]).astype(np.int32).tobytes() | |||||
| result = stub.Predict(request) | |||||
| #result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims) | |||||
| print("ms client received: ") | |||||
| #print(result_np) | |||||
| # future_list = [] | |||||
| # times = 1000 | |||||
| # for i in range(times): | |||||
| # async_future = stub.Eval.future(request) | |||||
| # future_list.append(async_future) | |||||
| # print("async call, future list add item " + str(i)); | |||||
| # | |||||
| # for i in range(len(future_list)): | |||||
| # async_result = future_list[i].result() | |||||
| # print("ms client async get result of item " + str(i)) | |||||
| if __name__ == '__main__': | |||||
| run() | |||||
| @@ -1,55 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from concurrent import futures | |||||
| import time | |||||
| import grpc | |||||
| import numpy as np | |||||
| import ms_service_pb2 | |||||
| import ms_service_pb2_grpc | |||||
| import test_cpu_lenet | |||||
| from mindspore import Tensor | |||||
| class MSService(ms_service_pb2_grpc.MSServiceServicer): | |||||
| def Predict(self, request, context): | |||||
| request_data = request.data | |||||
| request_label = request.label | |||||
| data_from_buffer = np.frombuffer(request_data.data, dtype=np.float32) | |||||
| data_from_buffer = data_from_buffer.reshape(request_data.tensor_shape.dims) | |||||
| data = Tensor(data_from_buffer) | |||||
| label_from_buffer = np.frombuffer(request_label.data, dtype=np.int32) | |||||
| label_from_buffer = label_from_buffer.reshape(request_label.tensor_shape.dims) | |||||
| label = Tensor(label_from_buffer) | |||||
| result = test_cpu_lenet.test_lenet(data, label) | |||||
| result_reply = ms_service_pb2.PredictReply() | |||||
| result_reply.result.tensor_shape.dims.extend(result.shape()) | |||||
| result_reply.result.data = result.asnumpy().tobytes() | |||||
| return result_reply | |||||
| def serve(): | |||||
| server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) | |||||
| ms_service_pb2_grpc.add_MSServiceServicer_to_server(MSService(), server) | |||||
| server.add_insecure_port('[::]:50051') | |||||
| server.start() | |||||
| try: | |||||
| while True: | |||||
| time.sleep(60*60*24) # one day in seconds | |||||
| except KeyboardInterrupt: | |||||
| server.stop(0) | |||||
| if __name__ == '__main__': | |||||
| serve() | |||||
| @@ -1,91 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| from mindspore.nn.optim import Momentum | |||||
| from mindspore.ops import operations as P | |||||
| import ms_service_pb2 | |||||
| class LeNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(LeNet, self).__init__() | |||||
| self.relu = P.ReLU() | |||||
| self.batch_size = 32 | |||||
| self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') | |||||
| self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') | |||||
| self.pool = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
| self.reshape = P.Reshape() | |||||
| self.fc1 = nn.Dense(400, 120) | |||||
| self.fc2 = nn.Dense(120, 84) | |||||
| self.fc3 = nn.Dense(84, 10) | |||||
| def construct(self, input_x): | |||||
| output = self.conv1(input_x) | |||||
| output = self.relu(output) | |||||
| output = self.pool(output) | |||||
| output = self.conv2(output) | |||||
| output = self.relu(output) | |||||
| output = self.pool(output) | |||||
| output = self.reshape(output, (self.batch_size, -1)) | |||||
| output = self.fc1(output) | |||||
| output = self.relu(output) | |||||
| output = self.fc2(output) | |||||
| output = self.relu(output) | |||||
| output = self.fc3(output) | |||||
| return output | |||||
| def train(net, data, label): | |||||
| learning_rate = 0.01 | |||||
| momentum = 0.9 | |||||
| optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) | |||||
| criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||||
| net_with_criterion = WithLossCell(net, criterion) | |||||
| train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer | |||||
| train_network.set_train() | |||||
| res = train_network(data, label) | |||||
| print("+++++++++Loss+++++++++++++") | |||||
| print(res) | |||||
| print("+++++++++++++++++++++++++++") | |||||
| assert res | |||||
| return res | |||||
| def test_lenet(data, label): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| net = LeNet() | |||||
| return train(net, data, label) | |||||
| if __name__ == '__main__': | |||||
| tensor = ms_service_pb2.Tensor() | |||||
| tensor.tensor_shape.dim.extend([32, 1, 32, 32]) | |||||
| # tensor.tensor_shape.dim.add() = 1 | |||||
| # tensor.tensor_shape.dim.add() = 32 | |||||
| # tensor.tensor_shape.dim.add() = 32 | |||||
| tensor.tensor_type = ms_service_pb2.MS_FLOAT32 | |||||
| tensor.data = np.ones([32, 1, 32, 32]).astype(np.float32).tobytes() | |||||
| data_from_buffer = np.frombuffer(tensor.data, dtype=np.float32) | |||||
| print(tensor.tensor_shape.dim) | |||||
| data_from_buffer = data_from_buffer.reshape(tensor.tensor_shape.dim) | |||||
| print(data_from_buffer.shape) | |||||
| input_data = Tensor(data_from_buffer * 0.01) | |||||
| input_label = Tensor(np.ones([32]).astype(np.int32)) | |||||
| test_lenet(input_data, input_label) | |||||