|
- package rpc
-
- import (
- context "context"
- "crypto/tls"
- "crypto/x509"
- "fmt"
- "sync"
- "time"
-
- grpc "google.golang.org/grpc"
- "google.golang.org/grpc/credentials"
- "google.golang.org/grpc/metadata"
- )
-
- type PoolConfig struct {
- RootCA *x509.CertPool
- // 客户端证书,与AccessTokenProvider二选一
- ClientCert *tls.Certificate
- // AccessTokenProvider,与ClientCert二选一
- AccessTokenProvider AccessTokenProvider
- }
-
- type ConnPool struct {
- cfg PoolConfig
- grpcCons map[string]*grpcCon
- lock sync.Mutex
- }
-
// grpcCon is one pooled connection plus its bookkeeping. Its fields are read
// and written only while ConnPool.lock is held.
type grpcCon struct {
	// grpcCon is the underlying client connection.
	grpcCon *grpc.ClientConn
	// refCount counts GetConnection calls not yet matched by Release.
	refCount int
	// stopClosing is non-nil while a delayed close scheduled by Release is
	// pending; GetConnection closes it to cancel that close.
	stopClosing chan any
}
-
- func NewConnPool(cfg PoolConfig) *ConnPool {
- return &ConnPool{
- cfg: cfg,
- grpcCons: make(map[string]*grpcCon),
- }
- }
-
- func (p *ConnPool) GetConnection(addr string) (*grpc.ClientConn, error) {
- p.lock.Lock()
- defer p.lock.Unlock()
-
- con := p.grpcCons[addr]
- if con == nil {
- gcon, err := p.connecting(addr)
- if err != nil {
- return nil, err
- }
-
- con = &grpcCon{
- grpcCon: gcon,
- refCount: 0,
- stopClosing: nil,
- }
-
- p.grpcCons[addr] = con
- } else if con.stopClosing != nil {
- close(con.stopClosing)
- con.stopClosing = nil
- }
-
- con.refCount++
-
- return con.grpcCon, nil
- }
-
- func (p *ConnPool) connecting(addr string) (*grpc.ClientConn, error) {
- if p.cfg.ClientCert != nil {
- gcon, err := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
- RootCAs: p.cfg.RootCA,
- Certificates: []tls.Certificate{*p.cfg.ClientCert},
- ServerName: InternalAPISNIV1,
- NextProtos: []string{"h2"},
- })))
- if err != nil {
- return nil, err
- }
-
- return gcon, nil
- }
-
- if p.cfg.AccessTokenProvider == nil {
- return nil, fmt.Errorf("no client cert or access token provider")
- }
-
- gcon, err := grpc.NewClient(addr,
- grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
- RootCAs: p.cfg.RootCA,
- ServerName: ClientAPISNIV1,
- NextProtos: []string{"h2"},
- })),
- grpc.WithUnaryInterceptor(p.populateAccessTokenUnary),
- grpc.WithStreamInterceptor(p.populateAccessTokenStream),
- )
- if err != nil {
- return nil, err
- }
-
- return gcon, nil
- }
- func (p *ConnPool) Release(addr string) {
- p.lock.Lock()
- defer p.lock.Unlock()
-
- grpcCon := p.grpcCons[addr]
- if grpcCon == nil {
- return
- }
-
- grpcCon.refCount--
- grpcCon.refCount = max(grpcCon.refCount, 0)
-
- if grpcCon.refCount == 0 {
- stopClosing := make(chan any)
- grpcCon.stopClosing = stopClosing
-
- go func() {
- select {
- case <-stopClosing:
- return
-
- case <-time.After(time.Minute):
- p.lock.Lock()
- defer p.lock.Unlock()
-
- grpcCon := p.grpcCons[addr]
- if grpcCon == nil {
- return
- }
-
- if grpcCon.refCount == 0 {
- grpcCon.grpcCon.Close()
- delete(p.grpcCons, addr)
- }
- }
- }()
- }
- }
-
- func (p *ConnPool) populateAccessTokenUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
- authInfo, err := p.cfg.AccessTokenProvider.MakeAuthInfo()
- if err != nil {
- return err
- }
-
- md := metadata.Pairs(
- MetaUserID, fmt.Sprintf("%v", authInfo.UserID),
- MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID),
- MetaNonce, authInfo.Nonce,
- MetaSignature, authInfo.Signature,
- )
-
- ctx = metadata.NewOutgoingContext(ctx, md)
- return invoker(ctx, method, req, reply, cc, opts...)
- }
-
- func (p *ConnPool) populateAccessTokenStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
- authInfo, err := p.cfg.AccessTokenProvider.MakeAuthInfo()
- if err != nil {
- return nil, err
- }
-
- md := metadata.Pairs(
- MetaUserID, fmt.Sprintf("%v", authInfo.UserID),
- MetaAccessTokenID, fmt.Sprintf("%v", authInfo.AccessTokenID),
- MetaNonce, authInfo.Nonce,
- MetaSignature, authInfo.Signature,
- )
-
- ctx = metadata.NewOutgoingContext(ctx, md)
- return streamer(ctx, desc, cc, method, opts...)
- }
|