Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 60 additions & 31 deletions lib/go/thrift/server_socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/


package thrift

import (
Expand All @@ -26,6 +27,8 @@ import (
)

type TServerSocket struct {
// TServerSocketListenerFactory abstracts how listeners are created.
listenerFactory func(net.Addr) (net.Listener, error)
addr net.Addr
clientTimeout time.Duration

Expand All @@ -44,28 +47,62 @@ func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*T
if err != nil {
return nil, err
}
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil

return NewTServerSocketFromAddrTimeout(addr, clientTimeout), nil
}

// NewTServerSocketFromAddrTimeout returns TServerSocket
// Creates a TServerSocket from a net.Addr
func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket {
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}
listenerFactory := func(addr net.Addr) (net.Listener, error) {
return net.Listen(addr.Network(), addr.String())
}

return NewTServerSocketFromFactoryTimeout(listenerFactory, addr, clientTimeout)
}

func (p *TServerSocket) Listen() error {
// NewTServerSocketFromFactoryTimeout creates TServerSocket via a listener factory.
//
// Allows full customization (TLS, mocks, unix sockets, windows named pipes, etc.)
func NewTServerSocketFromFactoryTimeout(listenerFactory func(addr net.Addr) (listener net.Listener, err error), addr net.Addr, clientTimeout time.Duration) *TServerSocket {
return &TServerSocket{
listenerFactory: listenerFactory,
addr: addr,
clientTimeout: clientTimeout,
}
}

func (p *TServerSocket) try_listen(raise bool) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {

if p.listener != nil {
if (raise) {
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
}
return nil
}
l, err := net.Listen(p.addr.Network(), p.addr.String())

l, err := p.listenerFactory(p.addr)
if err != nil {
return err
}

p.listener = l
p.interrupted = false
return nil
}

// Open does try to listen and return on failure
// Connects the socket, creating a new socket object if necessary.
func (p *TServerSocket) Open() error {
return p.try_listen(true /* raise error if listening */)
}

func (p *TServerSocket) Listen() error {
return p.try_listen(false /* do not raise error if listening */)
}

func (p *TServerSocket) Accept() (TTransport, error) {
p.mu.RLock()
interrupted := p.interrupted
Expand All @@ -87,51 +124,43 @@ func (p *TServerSocket) Accept() (TTransport, error) {
return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil
}

// IsListening returns listener != nil
// Checks whether the socket is listening.
func (p *TServerSocket) IsListening() bool {
p.mu.RLock()
defer p.mu.RUnlock()
return p.listener != nil
}

// Connects the socket, creating a new socket object if necessary.
func (p *TServerSocket) Open() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
}
if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil {
return err
} else {
p.listener = l
}
return nil
}

func (p *TServerSocket) Addr() net.Addr {
p.mu.RLock()
defer p.mu.RUnlock()
if p.IsListening() {

if p.listener != nil {
return p.listener.Addr()
}
return p.addr
}

func (p *TServerSocket) Close() error {
var err error
func (p *TServerSocket) try_close(interrupt bool) error {
p.mu.Lock()
if p.IsListening() {
defer p.mu.Unlock()
if (interrupt){
p.interrupted = true
}

var err error = nil
if p.listener != nil {
err = p.listener.Close()
p.listener = nil
}
p.mu.Unlock()
return err
}

func (p *TServerSocket) Interrupt() error {
p.mu.Lock()
p.interrupted = true
p.mu.Unlock()
p.Close()
func (p *TServerSocket) Close() error {
return p.try_close(false /* do not set interrupted flag */)
}

return nil
func (p *TServerSocket) Interrupt() error {
return p.try_close(true /* set interrupted flag */)
}
Loading