diff --git a/go.mod b/go.mod index c82093e..95d48a6 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.0 toolchain go1.24.4 require ( + github.com/pion/stun/v3 v3.0.1 github.com/pion/turn/v4 v4.1.3 github.com/pion/webrtc/v4 v4.1.6 github.com/prometheus/client_golang v1.23.2 @@ -29,7 +30,6 @@ require ( github.com/pion/sctp v1.8.40 // indirect github.com/pion/sdp/v3 v3.0.16 // indirect github.com/pion/srtp/v3 v3.0.8 // indirect - github.com/pion/stun/v3 v3.0.1 // indirect github.com/pion/transport/v3 v3.1.1 // indirect github.com/wlynxg/anet v0.0.5 // indirect golang.org/x/net v0.45.0 // indirect diff --git a/internal/ice/network_discover.go b/internal/ice/network_discover.go new file mode 100644 index 0000000..e9d3651 --- /dev/null +++ b/internal/ice/network_discover.go @@ -0,0 +1,373 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "errors" + "fmt" + "log/slog" + "net" + "time" + + "github.com/pion/stun/v3" +) + +const ( + defaultStunAddr = "stun.voipgate.com:3478" + defaultTimeout = 1 * time.Second + bufferSize = 1024 +) + +type stunServerConn struct { + conn net.PacketConn + LocalAddr net.Addr + RemoteAddr *net.UDPAddr + OtherAddr *net.UDPAddr + messageChan chan *stun.Message +} + +type ( + natBehavior int +) + +const ( + NoNAT natBehavior = iota + EpIndependent + AddrDependent + AddrPrtDependent +) + +var ( + errConnectStun = errors.New("cannot connect to stun server") + errRoundTrip = errors.New("cannot make round trip to stun server") + errDiscoverMapping = errors.New("error discovering nat mapping") + errDiscoverFiltering = errors.New("error discovering nat filtering") + errDiscoverLocal = errors.New("cannot discover local IP") + errNoLocalIPFound = errors.New("no valid local ip address found") + errResponseMessage = errors.New("error reading from response message channel") + errTimedOut = errors.New("timed out waiting for response") + errNoOtherAddress = errors.New("no OTHER-ADDRESS in message") +) + +func (c *stunServerConn) Close() error { + return c.conn.Close() +} + +// Utility functions for NAT detection, local network discovery. + +// DiscoverLocalIP returns IP address for local interface. +func DiscoverLocalIP() (string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", fmt.Errorf("%w: %w", errDiscoverLocal, err) + } + for _, address := range addrs { + // check the address type and if it is not a loopback then display it + if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String(), nil + } + } + } + + return "", errNoLocalIPFound +} + +// DiscoverNatMapping determines NAT mapping under RFC5780: 4.3. +// Adapted from pion/stun. +func DiscoverNatMapping(stunAddr string, log *slog.Logger) (natBehavior, error) { //nolint:cyclop + if stunAddr == "" { + stunAddr = defaultStunAddr + } + log.Info("Discovering NAT mapping", "stunAddr", stunAddr) + + mapTestConn, err := connect(stunAddr, log) + defer func() { + if mapTestConn != nil { + if cerr := mapTestConn.Close(); cerr != nil { + log.Warn(cerr.Error()) + } + } + }() + if err != nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverMapping, err) + } + + localAddr, err := DiscoverLocalIP() + if err != nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverMapping, err) + } + + // Test I: Regular binding request + request := stun.MustBuild(stun.TransactionID, stun.BindingRequest) + + resp, err := mapTestConn.roundTrip(request, mapTestConn.RemoteAddr, log) + if err != nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverMapping, err) + } + + // Parse response message for XOR-MAPPED-ADDRESS and make sure OTHER-ADDRESS valid + resps1 := parse(resp, log) + if resps1.xorAddr == nil || resps1.otherAddr == nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverMapping, errNoOtherAddress) + } + addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String()) + if err != nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverMapping, err) + } + mapTestConn.OtherAddr = addr + log.Debug("", "Received XOR-MAPPED-ADDRESS", resps1.xorAddr) + + // Assert mapping behavior + if resps1.xorAddr.IP.String() == localAddr { + log.Info("NAT mapping behavior: endpoint independent (no NAT)") + + return NoNAT, nil + } + + // Test II: Send binding request to the other address but primary port + log.Info("Mapping Test II: Send binding request to the other address but primary port") + oaddr := *mapTestConn.OtherAddr + oaddr.Port = mapTestConn.RemoteAddr.Port + resp, err = mapTestConn.roundTrip(request, &oaddr, log) + if err != nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverMapping, err) + } + + // Assert mapping behavior + resps2 := parse(resp, log) + log.Debug("", "Received XOR-MAPPED-ADDRESS", resps2.xorAddr) + if resps2.xorAddr.String() == resps1.xorAddr.String() { + log.Info("NAT mapping behavior: endpoint independent") + + return EpIndependent, nil + } + + // Test III: Send binding request to the other address and port + log.Debug("Mapping Test III: Send binding request to the other address and port") + resp, err = mapTestConn.roundTrip(request, mapTestConn.OtherAddr, log) + if err != nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverMapping, err) + } + + // Assert mapping behavior + resps3 := parse(resp, log) + log.Debug("", "Received XOR-MAPPED-ADDRESS", resps3.xorAddr) + if resps3.xorAddr.String() == resps2.xorAddr.String() { + log.Warn("NAT mapping behavior: address dependent") + + return AddrDependent, nil + } else { + log.Warn("NAT mapping behavior: address and port dependent") + + return AddrPrtDependent, nil + } +} + +// DiscoverNatFiltering determines NAT filtering behavior under RFC5780: 4.4. +// Adapted from pion/stun. +func DiscoverNatFiltering(stunAddr string, log *slog.Logger) (natBehavior, error) { //nolint:cyclop + if stunAddr == "" { + stunAddr = defaultStunAddr + } + log.Info("Discovering NAT filtering", "stunAddr", stunAddr) + + mapTestConn, err := connect(stunAddr, log) + defer func() { + if mapTestConn != nil { + if cerr := mapTestConn.Close(); cerr != nil { + log.Warn(cerr.Error()) + } + } + }() + if err != nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverFiltering, err) + } + + // Test I: Regular binding request + log.Info("Filtering Test I: Regular binding request") + request := stun.MustBuild(stun.TransactionID, stun.BindingRequest) + + resp, err := mapTestConn.roundTrip(request, mapTestConn.RemoteAddr, log) + if err != nil || errors.Is(err, errTimedOut) { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverFiltering, err) + } + resps := parse(resp, log) + if resps.xorAddr == nil || resps.otherAddr == nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverFiltering, errNoOtherAddress) + } + addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String()) + if err != nil { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverFiltering, err) + } + mapTestConn.OtherAddr = addr + + // Test II: Request to change both IP and port + log.Info("Filtering Test II: Request to change both IP and port") + request = stun.MustBuild(stun.TransactionID, stun.BindingRequest) + request.Add(stun.AttrChangeRequest, []byte{0x00, 0x00, 0x00, 0x06}) + + resp, err = mapTestConn.roundTrip(request, mapTestConn.RemoteAddr, log) + if err == nil { + parse(resp, log) // just to print out the resp + log.Info("NAT filtering behavior: endpoint independent") + + return EpIndependent, nil + } else if !errors.Is(err, errTimedOut) { + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverFiltering, err) // something else went wrong + } + + // Test III: Request to change port only + log.Info("Filtering Test III: Request to change port only") + request = stun.MustBuild(stun.TransactionID, stun.BindingRequest) + request.Add(stun.AttrChangeRequest, []byte{0x00, 0x00, 0x00, 0x02}) + + resp, err = mapTestConn.roundTrip(request, mapTestConn.RemoteAddr, log) + switch { + case err == nil: + { + parse(resp, log) + log.Warn("=> NAT filtering behavior: address dependent") + + return AddrDependent, nil + } + case errors.Is(err, errTimedOut): + { + log.Warn("=> NAT filtering behavior: address and port dependent") + + return AddrPrtDependent, nil + } + default: + return AddrPrtDependent, fmt.Errorf("%w: %w", errDiscoverFiltering, err) + } +} + +// Given an address string, returns a StunServerConn. +func connect(addrStr string, log *slog.Logger) (*stunServerConn, error) { + log.Debug("Connecting to STUN server", "server", addrStr) + addr, err := net.ResolveUDPAddr("udp4", addrStr) + if err != nil { + return nil, fmt.Errorf("%w: addr: %v,%w", errConnectStun, addrStr, err) + } + + c, err := net.ListenUDP("udp4", nil) + if err != nil { + return nil, fmt.Errorf("%w: addr: %v, %w", errConnectStun, addrStr, err) + } + log.Debug("", "Local address", c.LocalAddr()) + log.Debug("", "Remote address", addr.String()) + + mChan := listen(c, log) + + return &stunServerConn{ + conn: c, + LocalAddr: c.LocalAddr(), + RemoteAddr: addr, + messageChan: mChan, + }, nil +} + +// Send request and wait for response or timeout. +// Adapted from pion/stun. +func (c *stunServerConn) roundTrip(msg *stun.Message, addr net.Addr, log *slog.Logger) (*stun.Message, error) { + _ = msg.NewTransactionID() + log.Debug(msg.String()) + _, err := c.conn.WriteTo(msg.Raw, addr) + if err != nil { + return nil, fmt.Errorf("%w: addr: %v, %w", errRoundTrip, addr.String(), err) + } + + // Wait for response or timeout + select { + case m, ok := <-c.messageChan: + if !ok { + return nil, errResponseMessage + } + + return m, nil + case <-time.After(defaultTimeout): + log.Error("Timed out waiting for response from server", "address", addr) + + return nil, fmt.Errorf("%w: addr:%v, %w", errRoundTrip, addr, errTimedOut) + } +} + +// taken from https://github.com/pion/stun/blob/master/cmd/stun-traversal/main.go +func listen(conn *net.UDPConn, log *slog.Logger) (messages chan *stun.Message) { + messages = make(chan *stun.Message) + go func() { + for { + buf := make([]byte, bufferSize) + + n, _, err := conn.ReadFromUDP(buf) + if err != nil { + close(messages) + + return + } + buf = buf[:n] + + m := new(stun.Message) + m.Raw = buf + err = m.Decode() + if err != nil { + log.Error(err.Error()) + close(messages) + + return + } + + messages <- m + } + }() + + return +} + +// Parse a STUN message. +// Adapted from pion/stun. +func parse(msg *stun.Message, log *slog.Logger) (ret struct { + xorAddr *stun.XORMappedAddress + otherAddr *stun.OtherAddress + respOrigin *stun.ResponseOrigin + mappedAddr *stun.MappedAddress + software *stun.Software +}, +) { + ret.mappedAddr = &stun.MappedAddress{} + ret.xorAddr = &stun.XORMappedAddress{} + ret.respOrigin = &stun.ResponseOrigin{} + ret.otherAddr = &stun.OtherAddress{} + ret.software = &stun.Software{} + if ret.xorAddr.GetFrom(msg) != nil { + ret.xorAddr = nil + } + if ret.otherAddr.GetFrom(msg) != nil { + ret.otherAddr = nil + } + if ret.respOrigin.GetFrom(msg) != nil { + ret.respOrigin = nil + } + if ret.mappedAddr.GetFrom(msg) != nil { + ret.mappedAddr = nil + } + if ret.software.GetFrom(msg) != nil { + ret.software = nil + } + for _, attr := range msg.Attributes { + switch attr.Type { + case + stun.AttrXORMappedAddress, + stun.AttrOtherAddress, + stun.AttrResponseOrigin, + stun.AttrMappedAddress, + stun.AttrSoftware: + break //nolint:staticcheck + default: + log.Debug(fmt.Sprintf("\t%v (l=%v)", attr, attr.Length)) + } + } + + return ret +} diff --git a/internal/ice/network_discover_test.go b/internal/ice/network_discover_test.go new file mode 100644 index 0000000..f5c4ac2 --- /dev/null +++ b/internal/ice/network_discover_test.go @@ -0,0 +1,659 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +package ice + +import ( + "errors" + "io" + "log/slog" + "net" + "sync" + "testing" + "time" + + "github.com/pion/stun/v3" + "github.com/stretchr/testify/require" +) + +// newTestLogger returns a slog.Logger that discards all output. +func newTestLogger(tb testing.TB) *slog.Logger { + tb.Helper() + + return slog.New( + slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }), + ) +} + +func TestDiscoverLocalIP_BestEffort(t *testing.T) { + // First check if this machine even has a non-loopback IPv4 address. + addrs, err := net.InterfaceAddrs() + require.NoError(t, err) + + hasNonLoopbackIPv4 := false + + for _, a := range addrs { + if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { + hasNonLoopbackIPv4 = true + + break + } + } + + if !hasNonLoopbackIPv4 { + t.Skip("no non-loopback IPv4 addresses on this host; skipping DiscoverLocalIP happy-path test") + } + + ipStr, err := DiscoverLocalIP() + require.NoError(t, err) + + ip := net.ParseIP(ipStr) + require.NotNil(t, ip, "DiscoverLocalIP should return a parseable IP") + require.NotNil(t, ip.To4(), "DiscoverLocalIP should return an IPv4 address") +} + +func TestConnect_Success(t *testing.T) { + log := newTestLogger(t) + + conn, err := connect("127.0.0.1:0", log) + require.NoError(t, err) + require.NotNil(t, conn) + + defer func() { + if cerr := conn.Close(); cerr != nil { + t.Logf("failed to close stunServerConn: %v", cerr) + } + }() + + require.NotNil(t, conn.conn) + require.NotNil(t, conn.LocalAddr) + require.NotNil(t, conn.RemoteAddr) + require.NotNil(t, conn.messageChan) +} + +func TestConnect_InvalidAddress(t *testing.T) { + log := newTestLogger(t) + + conn, err := connect("this-is-not-a-valid-host:9999", log) + require.Error(t, err) + require.Nil(t, conn) + require.ErrorIs(t, err, errConnectStun) +} + +func TestStunServerConn_Close(t *testing.T) { + pc, err := net.ListenPacket("udp4", "127.0.0.1:0") //nolint:noctx + require.NoError(t, err) + + defer func() { + if cerr := pc.Close(); cerr != nil { + t.Logf("failed to close PacketConn: %v", cerr) + } + }() + + s := &stunServerConn{ + conn: pc, + LocalAddr: pc.LocalAddr(), + } + + err = s.Close() + require.NoError(t, err) +} + +func TestRoundTrip_Success(t *testing.T) { + log := newTestLogger(t) + + pc, err := net.ListenPacket("udp4", "127.0.0.1:0") //nolint:noctx + require.NoError(t, err) + + defer func() { + if cerr := pc.Close(); cerr != nil { + t.Logf("failed to close PacketConn: %v", cerr) + } + }() + + remote := pc.LocalAddr() + + srv := &stunServerConn{ //nolint:forcetypeassert + conn: pc, + LocalAddr: pc.LocalAddr(), + RemoteAddr: remote.(*net.UDPAddr), + messageChan: make(chan *stun.Message, 1), + } + + respMsg := stun.MustBuild(stun.TransactionID, stun.BindingSuccess) + + go func() { + srv.messageChan <- respMsg + }() + + req := stun.MustBuild(stun.TransactionID, stun.BindingRequest) + + got, err := srv.roundTrip(req, srv.RemoteAddr, log) + require.NoError(t, err) + require.NotNil(t, got) +} + +func TestRoundTrip_WriteError(t *testing.T) { + log := newTestLogger(t) + + pc, err := net.ListenPacket("udp4", "127.0.0.1:0") //nolint:noctx + require.NoError(t, err) + + if cerr := pc.Close(); cerr != nil { + t.Logf("failed to close PacketConn: %v", cerr) + } + + s := &stunServerConn{ + conn: pc, + messageChan: make(chan *stun.Message), + } + + req := stun.MustBuild(stun.TransactionID, stun.BindingRequest) + + got, err := s.roundTrip(req, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9999}, log) + require.Nil(t, got) + require.Error(t, err) + require.ErrorIs(t, err, errRoundTrip) +} + +func TestRoundTrip_ChannelClosed(t *testing.T) { + log := newTestLogger(t) + + pc, err := net.ListenPacket("udp4", "127.0.0.1:0") //nolint:noctx + require.NoError(t, err) + + defer func() { + if cerr := pc.Close(); cerr != nil { + t.Logf("failed to close PacketConn: %v", cerr) + } + }() + + ch := make(chan *stun.Message) + close(ch) + + s := &stunServerConn{ //nolint:forcetypeassert + conn: pc, + LocalAddr: pc.LocalAddr(), + RemoteAddr: pc.LocalAddr().(*net.UDPAddr), + messageChan: ch, + } + + req := stun.MustBuild(stun.TransactionID, stun.BindingRequest) + + got, err := s.roundTrip(req, s.RemoteAddr, log) + require.Nil(t, got) + require.ErrorIs(t, err, errResponseMessage) +} + +func TestListen_ClosesChannelOnReadError(t *testing.T) { + log := newTestLogger(t) + + udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + + msgCh := listen(udpConn, log) + + if cerr := udpConn.Close(); cerr != nil { + t.Logf("failed to close UDPConn: %v", cerr) + } + + if _, ok := <-msgCh; ok { + require.True(t, ok) + } +} + +func TestParse_ReturnsAttributesWhenPresent(t *testing.T) { + log := newTestLogger(t) + + var msg stun.Message + msg.Type = stun.BindingSuccess + msg.TransactionID = [stun.TransactionIDSize]byte{1, 2, 3, 4} + + xorAddr := &stun.XORMappedAddress{ + IP: net.IPv4(192, 0, 2, 1), + Port: 3478, + } + require.NoError(t, xorAddr.AddTo(&msg)) + + otherAddr := &stun.OtherAddress{ + IP: net.IPv4(192, 0, 2, 2), + Port: 3479, + } + require.NoError(t, otherAddr.AddTo(&msg)) + + respOrigin := &stun.ResponseOrigin{ + IP: net.IPv4(192, 0, 2, 3), + Port: 3480, + } + require.NoError(t, respOrigin.AddTo(&msg)) + + mappedAddr := &stun.MappedAddress{ + IP: net.IPv4(192, 0, 2, 4), + Port: 3481, + } + require.NoError(t, mappedAddr.AddTo(&msg)) + + software := &stun.Software{} + require.NoError(t, software.AddTo(&msg)) + + ret := parse(&msg, log) + + require.NotNil(t, ret.xorAddr, "xorAddr should be non-nil when XOR-MAPPED-ADDRESS is present") + require.NotNil(t, ret.otherAddr, "otherAddr should be non-nil when OTHER-ADDRESS is present") + require.NotNil(t, ret.respOrigin, "respOrigin should be non-nil when RESPONSE-ORIGIN is present") + require.NotNil(t, ret.mappedAddr, "mappedAddr should be non-nil when MAPPED-ADDRESS is present") + require.NotNil(t, ret.software, "software should be non-nil when SOFTWARE is present") +} + +func TestParse_AllowsMissingAttributes(t *testing.T) { + log := newTestLogger(t) + + msg := stun.MustBuild(stun.TransactionID, stun.BindingSuccess) + + ret := parse(msg, log) + + require.Nil(t, ret.xorAddr) + require.Nil(t, ret.otherAddr) + require.Nil(t, ret.respOrigin) + require.Nil(t, ret.mappedAddr) + require.Nil(t, ret.software) +} + +func TestDiscoverNatMapping_ConnectError(t *testing.T) { + log := newTestLogger(t) + + _, err := DiscoverNatMapping("invalid-address", log) + require.Error(t, err) + require.ErrorIs(t, err, errDiscoverMapping) +} + +type testStunServer struct { + conn *net.UDPConn + handler func(call int, req *stun.Message) (*stun.Message, bool) + callNum int + mu sync.Mutex +} + +func newTestStunServer(t *testing.T) *testStunServer { + t.Helper() + + addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} + conn, err := net.ListenUDP("udp4", addr) + require.NoError(t, err) + + srv := &testStunServer{ + conn: conn, + } + + go srv.serve(t) + + t.Cleanup(func() { + if cerr := srv.conn.Close(); cerr != nil { + t.Logf("failed to close test STUN server: %v", cerr) + } + }) + + return srv +} + +func (s *testStunServer) Addr() string { + return s.conn.LocalAddr().String() +} + +func (s *testStunServer) serve(t *testing.T) { + t.Helper() + buf := make([]byte, 1500) + + for { + n, addr, err := s.conn.ReadFromUDP(buf) + if err != nil { + return + } + + msg := &stun.Message{Raw: append([]byte(nil), buf[:n]...)} + if decodeErr := msg.Decode(); err != nil { + t.Logf("failed to decode STUN request: %v", decodeErr) + + continue + } + + s.mu.Lock() + s.callNum++ + call := s.callNum + s.mu.Unlock() + + resp, ok := s.handler(call, msg) + if !ok || resp == nil { + continue + } + + _, err = s.conn.WriteTo(resp.Raw, addr) + if err != nil { + t.Logf("failed to write STUN response: %v", err) + } + } +} + +func (s *testStunServer) SetHandler(h func(call int, req *stun.Message) (*stun.Message, bool)) { + s.mu.Lock() + defer s.mu.Unlock() + s.handler = h +} + +func stunResp(attrs ...stun.Setter) *stun.Message { + return stun.MustBuild( + attrs..., + ) +} + +/* ------------------------------------------------------------- + NAT MAPPING TESTS + ------------------------------------------------------------- */ + +func TestDiscoverNatMapping_NoNAT(t *testing.T) { + log := newTestLogger(t) + + localIP, err := DiscoverLocalIP() + if errors.Is(err, errNoLocalIPFound) { + t.Skip() + } + require.NoError(t, err) + + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + if call != 1 { + return nil, false + } + + //nolint:forcetypeassert + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP(localIP), Port: 5000}, + &stun.OtherAddress{IP: net.ParseIP("127.0.0.1"), Port: srv.conn.LocalAddr().(*net.UDPAddr).Port}, + ), true + }) + + behavior, derr := DiscoverNatMapping(srv.Addr(), log) + require.NoError(t, derr) + require.Equal(t, NoNAT, behavior) +} + +func TestDiscoverNatMapping_EpIndependent(t *testing.T) { + log := newTestLogger(t) + + _, err := DiscoverLocalIP() + if errors.Is(err, errNoLocalIPFound) { + t.Skip() + } + require.NoError(t, err) + + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + switch call { + case 1: + //nolint:forcetypeassert + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.10"), Port: 5000}, + &stun.OtherAddress{IP: net.ParseIP("127.0.0.1"), Port: srv.conn.LocalAddr().(*net.UDPAddr).Port}, + ), true + case 2: + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.10"), Port: 5000}, + ), true + default: + return nil, false + } + }) + + behavior, derr := DiscoverNatMapping(srv.Addr(), log) + require.NoError(t, derr) + require.Equal(t, EpIndependent, behavior) +} + +func TestDiscoverNatMapping_Dependent(t *testing.T) { + log := newTestLogger(t) + + _, err := DiscoverLocalIP() + require.NoError(t, err) + + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + srvPort := srv.conn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert + switch call { + case 1: + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.1"), Port: 5000}, + &stun.OtherAddress{IP: net.ParseIP("127.0.0.1"), Port: srvPort}, + ), true + + case 2: + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.2"), Port: 5001}, + ), true + + case 3: + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.2"), Port: 5001}, + ), true + case 4: + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.1"), Port: 5000}, + &stun.OtherAddress{IP: net.ParseIP("127.0.0.1"), Port: srvPort}, + ), true + + case 5: + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.2"), Port: 5001}, + ), true + + case 6: + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.2"), Port: 5002}, + ), true + + default: + return nil, false + } + }) + + behavior, err2 := DiscoverNatMapping(srv.Addr(), log) + require.NoError(t, err2) + require.Equal(t, AddrDependent, behavior) + behavior, err2 = DiscoverNatMapping(srv.Addr(), log) + require.NoError(t, err2) + require.Equal(t, AddrPrtDependent, behavior) +} + +func TestDiscoverNatMapping_MissingOtherAddress(t *testing.T) { + log := newTestLogger(t) + + _, err := DiscoverLocalIP() + require.NoError(t, err) + + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.5"), Port: 5000}, + ), true + }) + + _, derr := DiscoverNatMapping(srv.Addr(), log) + require.Error(t, derr) + require.ErrorIs(t, derr, errDiscoverMapping) + require.ErrorIs(t, derr, errNoOtherAddress) +} + +func TestDiscoverNatMapping_Timeout(t *testing.T) { + log := newTestLogger(t) + + _, err := DiscoverLocalIP() + require.NoError(t, err) + + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + return nil, false // cause timeout + }) + + start := time.Now() + behavior, derr := DiscoverNatMapping(srv.Addr(), log) + elapsed := time.Since(start) + + require.Equal(t, AddrPrtDependent, behavior) + require.Error(t, derr) + require.ErrorIs(t, derr, errTimedOut) + require.GreaterOrEqual(t, elapsed, defaultTimeout) +} + +// /* ------------------------------------------------------------- +// NAT FILTERING TESTS +// ------------------------------------------------------------- */ + +func TestDiscoverNatFiltering_ConnectError(t *testing.T) { + log := newTestLogger(t) + + behavior, err := DiscoverNatFiltering("%invalid", log) + require.Equal(t, AddrPrtDependent, behavior) + require.Error(t, err) + require.ErrorIs(t, err, errDiscoverFiltering) +} + +func TestDiscoverNatFiltering_MissingOtherAddress(t *testing.T) { + log := newTestLogger(t) + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + switch call { + case 1: + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.10"), Port: 5000}, + ), true + default: + return nil, false + } + }) + + _, err := DiscoverNatFiltering(srv.Addr(), log) + require.Error(t, err) + require.ErrorIs(t, err, errDiscoverFiltering) + require.ErrorIs(t, err, errNoOtherAddress) +} + +func TestDiscoverNatFiltering_EpIndependent(t *testing.T) { + log := newTestLogger(t) + + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + switch call { + case 1: + //nolint:forcetypeassert + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.10"), Port: 5000}, + &stun.OtherAddress{IP: net.ParseIP("127.0.0.1"), Port: srv.conn.LocalAddr().(*net.UDPAddr).Port}, + ), true + case 2: + // Test II: respond → EpIndependent filtering + return stunResp(), true + default: + return nil, false + } + }) + + behavior, err := DiscoverNatFiltering(srv.Addr(), log) + require.NoError(t, err) + require.Equal(t, EpIndependent, behavior) +} + +func TestDiscoverNatFiltering_AddrDependent(t *testing.T) { + log := newTestLogger(t) + + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + switch call { + case 1: + //nolint:forcetypeassert + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.10"), Port: 5000}, + &stun.OtherAddress{IP: net.ParseIP("127.0.0.1"), Port: srv.conn.LocalAddr().(*net.UDPAddr).Port}, + ), true + + case 2: + return nil, false // timeout + + case 3: + // Test III: respond → address dependent + return stunResp(), true + + default: + return nil, false + } + }) + + start := time.Now() + behavior, err := DiscoverNatFiltering(srv.Addr(), log) + elapsed := time.Since(start) + + require.NoError(t, err) + require.Equal(t, AddrDependent, behavior) + require.GreaterOrEqual(t, elapsed, defaultTimeout) +} + +func TestDiscoverNatFiltering_AddrEpDependent(t *testing.T) { + log := newTestLogger(t) + + srv := newTestStunServer(t) + srv.SetHandler( + func(call int, req *stun.Message) (*stun.Message, bool) { + switch call { + case 1: + //nolint:forcetypeassert + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.10"), Port: 5000}, + &stun.OtherAddress{IP: net.ParseIP("127.0.0.1"), Port: srv.conn.LocalAddr().(*net.UDPAddr).Port}, + ), true + + case 2: + return nil, false // timeout Test II + case 3: + return nil, false // timeout Test III → AddrEpDependent + + default: + return nil, false + } + }) + + start := time.Now() + behavior, err := DiscoverNatFiltering(srv.Addr(), log) + elapsed := time.Since(start) + + require.NoError(t, err) + require.Equal(t, AddrPrtDependent, behavior) + require.GreaterOrEqual(t, elapsed, 2*defaultTimeout) +} + +func TestDiscoverNatFiltering_TestII_NonTimeoutErr(t *testing.T) { + log := newTestLogger(t) + + srv := newTestStunServer(t) + srv.SetHandler(func(call int, req *stun.Message) (*stun.Message, bool) { + if call == 1 { + //nolint:forcetypeassert + return stunResp( + &stun.XORMappedAddress{IP: net.ParseIP("203.0.113.10"), Port: 5000}, + &stun.OtherAddress{IP: net.ParseIP("127.0.0.1"), Port: srv.conn.LocalAddr().(*net.UDPAddr).Port}), true + } + if call == 2 { + resp := &stun.Message{Raw: []byte("not-a-stun")} + + return resp, true + } + + return nil, false + }) + + _, err := DiscoverNatFiltering(srv.Addr(), log) + require.Error(t, err) + require.ErrorIs(t, err, errDiscoverFiltering) +}