From 99f0aaf7e2620204d78cd9dd1f69fff7db7d4985 Mon Sep 17 00:00:00 2001 From: qr243vbi Date: Fri, 24 Apr 2026 01:49:26 +0800 Subject: [PATCH] Update server_socket.go --- lib/go/thrift/server_socket.go | 91 ++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 31 deletions(-) diff --git a/lib/go/thrift/server_socket.go b/lib/go/thrift/server_socket.go index 164221e92b2..798b7a50070 100644 --- a/lib/go/thrift/server_socket.go +++ b/lib/go/thrift/server_socket.go @@ -17,6 +17,7 @@ * under the License. */ + package thrift import ( @@ -26,6 +27,8 @@ import ( ) type TServerSocket struct { + // TServerSocketListenerFactory abstracts how listeners are created. + listenerFactory func(net.Addr) (net.Listener, error) addr net.Addr clientTimeout time.Duration @@ -44,28 +47,62 @@ func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*T if err != nil { return nil, err } - return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil + + return NewTServerSocketFromAddrTimeout(addr, clientTimeout), nil } +// NewTServerSocketFromAddrTimeout returns TServerSocket // Creates a TServerSocket from a net.Addr func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket { - return &TServerSocket{addr: addr, clientTimeout: clientTimeout} + listenerFactory := func(addr net.Addr) (net.Listener, error) { + return net.Listen(addr.Network(), addr.String()) + } + + return NewTServerSocketFromFactoryTimeout(listenerFactory, addr, clientTimeout) } -func (p *TServerSocket) Listen() error { +// NewTServerSocketFromFactoryTimeout creates TServerSocket via a listener factory. +// +// Allows full customization (TLS, mocks, unix sockets, windows named pipes, etc.) +func NewTServerSocketFromFactoryTimeout(listenerFactory func(addr net.Addr) (listener net.Listener, err error), addr net.Addr, clientTimeout time.Duration) *TServerSocket { + return &TServerSocket{ + listenerFactory: listenerFactory, + addr: addr, + clientTimeout: clientTimeout, + } +} + +func (p *TServerSocket) try_listen(raise bool) error { p.mu.Lock() defer p.mu.Unlock() - if p.IsListening() { + + if p.listener != nil { + if (raise) { + return NewTTransportException(ALREADY_OPEN, "Server socket already open") + } return nil } - l, err := net.Listen(p.addr.Network(), p.addr.String()) + + l, err := p.listenerFactory(p.addr) if err != nil { return err } + p.listener = l + p.interrupted = false return nil } +// Open does try to listen and return on failure +// Connects the socket, creating a new socket object if necessary. +func (p *TServerSocket) Open() error { + return p.try_listen(true /* raise error if listening */) +} + +func (p *TServerSocket) Listen() error { + return p.try_listen(false /* do not raise error if listening */) +} + func (p *TServerSocket) Accept() (TTransport, error) { p.mu.RLock() interrupted := p.interrupted @@ -87,51 +124,43 @@ func (p *TServerSocket) Accept() (TTransport, error) { return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil } +// IsListening returns listener != nil // Checks whether the socket is listening. func (p *TServerSocket) IsListening() bool { + p.mu.RLock() + defer p.mu.RUnlock() return p.listener != nil } -// Connects the socket, creating a new socket object if necessary. -func (p *TServerSocket) Open() error { - p.mu.Lock() - defer p.mu.Unlock() - if p.IsListening() { - return NewTTransportException(ALREADY_OPEN, "Server socket already open") - } - if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil { - return err - } else { - p.listener = l - } - return nil -} - func (p *TServerSocket) Addr() net.Addr { p.mu.RLock() defer p.mu.RUnlock() - if p.IsListening() { + + if p.listener != nil { return p.listener.Addr() } return p.addr } -func (p *TServerSocket) Close() error { - var err error +func (p *TServerSocket) try_close(interrupt bool) error { p.mu.Lock() - if p.IsListening() { + defer p.mu.Unlock() + if (interrupt){ + p.interrupted = true + } + + var err error = nil + if p.listener != nil { err = p.listener.Close() p.listener = nil } - p.mu.Unlock() return err } -func (p *TServerSocket) Interrupt() error { - p.mu.Lock() - p.interrupted = true - p.mu.Unlock() - p.Close() +func (p *TServerSocket) Close() error { + return p.try_close(false /* do not set interrupted flag */) +} - return nil +func (p *TServerSocket) Interrupt() error { + return p.try_close(true /* set interrupted flag */) }