package rpc import ( "crypto/tls" "fmt" "strconv" jcstypes "gitlink.org.cn/cloudream/jcs-pub/common/types" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" ) const ( ClientAPISNIV1 = "rpc.client.jcs-pub.v1" InternalAPISNIV1 = "rpc.internal.jcs-pub.v1" MetaUserID = "x-jcs-user-id" MetaAccessTokenID = "x-jcs-access-token-id" MetaNonce = "x-jcs-nonce" MetaSignature = "x-jcs-signature" MetaTokenAuthInfo = "x-jcs-token-auth-info" ) type AccessTokenAuthInfo struct { UserID jcstypes.UserID AccessTokenID jcstypes.AccessTokenID Nonce string Signature string } type AccessTokenVerifier interface { Verify(authInfo AccessTokenAuthInfo) bool } type AccessTokenProvider interface { MakeAuthInfo() (AccessTokenAuthInfo, error) } func (s *ServerBase) tlsConfigSelector(hello *tls.ClientHelloInfo) (*tls.Config, error) { switch hello.ServerName { case ClientAPISNIV1: return &tls.Config{ Certificates: []tls.Certificate{s.serverCert}, ClientAuth: tls.NoClientCert, NextProtos: []string{"h2"}, }, nil case InternalAPISNIV1: return &tls.Config{ Certificates: []tls.Certificate{s.serverCert}, ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: s.rootCA, NextProtos: []string{"h2"}, }, nil default: return nil, fmt.Errorf("unknown server name: %s", hello.ServerName) } } func (s *ServerBase) authUnary( ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (resp any, err error) { pr, ok := peer.FromContext(ctx) if !ok { return nil, status.Error(codes.Unauthenticated, "no peer found in context") } tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo) if !ok { return nil, status.Error(codes.Unauthenticated, "no tls info found in peer") } // 如果是使用interanl ServerName通过的TLS认证,则直接放行 if tlsInfo.State.ServerName == InternalAPISNIV1 { return handler(ctx, req) } // 如果是无需认证的API,则直接放行 if s.noAuthAPIs[info.FullMethod] { return handler(ctx, req) } // 否则要进行额外的Token认证 if !s.accessTokenAuthAPIs[info.FullMethod] { return nil, status.Error(codes.Unauthenticated, "unauthorized access") } meta, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, status.Error(codes.Unauthenticated, "no metadata found in context") } userIDs := meta.Get(MetaUserID) if len(userIDs) != 1 { return nil, status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata") } userID, err := strconv.ParseInt(userIDs[0], 10, 64) if err != nil { return nil, status.Error(codes.Unauthenticated, "invalid user id in metadata") } accessTokenIDs := meta.Get(MetaAccessTokenID) if len(accessTokenIDs) != 1 { return nil, status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata") } nonce := meta.Get(MetaNonce) if len(nonce) != 1 { return nil, status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata") } signature := meta.Get(MetaSignature) if len(signature) != 1 { return nil, status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata") } authInfo := AccessTokenAuthInfo{ UserID: jcstypes.UserID(userID), AccessTokenID: jcstypes.AccessTokenID(accessTokenIDs[0]), Nonce: nonce[0], Signature: signature[0], } if !s.tokenVerifier.Verify(authInfo) { return nil, status.Error(codes.Unauthenticated, "invalid access token") } ctx = context.WithValue(ctx, MetaTokenAuthInfo, authInfo) return handler(ctx, req) } func (s *ServerBase) authStream( srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { pr, ok := peer.FromContext(stream.Context()) if !ok { return status.Error(codes.Unauthenticated, "no peer found in context") } tlsInfo, ok := pr.AuthInfo.(credentials.TLSInfo) if !ok { return status.Error(codes.Unauthenticated, "no tls info found in peer") } // 如果是使用interanl ServerName通过的TLS认证,则直接放行 if tlsInfo.State.ServerName == InternalAPISNIV1 { return handler(srv, stream) } // 如果是无需认证的API,则直接放行 if s.noAuthAPIs[info.FullMethod] { return handler(srv, stream) } // 否则要进行额外的Token认证 if !s.accessTokenAuthAPIs[info.FullMethod] { return status.Error(codes.Unauthenticated, "unauthorized access") } meta, ok := metadata.FromIncomingContext(stream.Context()) if !ok { return status.Error(codes.Unauthenticated, "no metadata found in context") } userIDs := meta.Get(MetaUserID) if len(userIDs) != 1 { return status.Error(codes.Unauthenticated, "missing or multiple user ids in metadata") } userID, err := strconv.ParseInt(userIDs[0], 10, 64) if err != nil { return status.Error(codes.Unauthenticated, "invalid user id in metadata") } accessTokenIDs := meta.Get(MetaAccessTokenID) if len(accessTokenIDs) != 1 { return status.Error(codes.Unauthenticated, "missing or multiple access token ids in metadata") } nonce := meta.Get(MetaNonce) if len(nonce) != 1 { return status.Error(codes.Unauthenticated, "missing or multiple nonces in metadata") } signature := meta.Get(MetaSignature) if len(signature) != 1 { return status.Error(codes.Unauthenticated, "missing or multiple signatures in metadata") } authInfo := AccessTokenAuthInfo{ UserID: jcstypes.UserID(userID), AccessTokenID: jcstypes.AccessTokenID(accessTokenIDs[0]), Nonce: nonce[0], Signature: signature[0], } if !s.tokenVerifier.Verify(authInfo) { return status.Error(codes.Unauthenticated, "invalid access token") } return handler(srv, &serverStream{stream, context.WithValue(stream.Context(), MetaTokenAuthInfo, authInfo)}) } type serverStream struct { grpc.ServerStream ctx context.Context } func (s *serverStream) Context() context.Context { return s.ctx } func GetAuthInfo(ctx context.Context) (AccessTokenAuthInfo, bool) { val := ctx.Value(MetaTokenAuthInfo) if val == nil { return AccessTokenAuthInfo{}, false } authInfo, ok := val.(AccessTokenAuthInfo) return authInfo, ok }