From 251638f59b9f98dfc8b44147e7dff15bf697ad68 Mon Sep 17 00:00:00 2001 From: Tony Blyler Date: Mon, 16 Jul 2018 17:13:30 -0400 Subject: [PATCH] Initial commit --- local/broadcast.go | 189 ++++++++++++++++++++++++++++ local/broadcast_test.go | 266 ++++++++++++++++++++++++++++++++++++++++ seeker.go | 9 ++ 3 files changed, 464 insertions(+) create mode 100644 local/broadcast.go create mode 100644 local/broadcast_test.go create mode 100644 seeker.go diff --git a/local/broadcast.go b/local/broadcast.go new file mode 100644 index 0000000..b0bd623 --- /dev/null +++ b/local/broadcast.go @@ -0,0 +1,189 @@ +package broadcast + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" +) + +const ( + // MaxUDPPacketSize is the max packet size for a broadcast message + // 0xffff - len(IPHeader) - len(UDPHeader) + MaxUDPPacketSize = 0xffff - 20 - 8 + + // DefaultBroadcastListenDeadline is the number of seconds to spend waiting for a packet + // before a loop and check of the context doneness + DefaultBroadcastListenDeadline = time.Second * time.Duration(5) + + // DefaultBroadcastSendDeadline is the number of seconds to spend sending a packet before + // giving up for a given broadcast address + DefaultBroadcastSendDeadline = time.Second * time.Duration(5) +) + +var ( + // ErrNotIPv4 if a non-IPv4 address was given when an IPv4 address was expected + ErrNotIPv4 = errors.New("non IPv4 address given") + + // ErrZeroPort if a zero port was defined for an IPv4 address when one must be specified + ErrZeroPort = errors.New("missing port definition in address") + + // ErrBroadcastSendBadSize if the message size of a message to send is too large + ErrBroadcastSendBadSize = fmt.Errorf("broadcast message size must be > 0B and <= %dB", MaxUDPPacketSize) +) + +// Broadcast implements Seeker with local network broadcasting with IPv4 broadcasting +type Broadcast struct { + listenAddr *net.UDPAddr + bcastIPs []net.IP + lconn *net.UDPConn + lconnLock sync.Mutex + + ListenDeadline time.Duration + SendDeadline time.Duration +} + +// NewBroadcast creates a new instance of Broadcast +func NewBroadcast(listenAddr *net.UDPAddr, bcastIPs []net.IP) (b *Broadcast, err error) { + err = validateBcastListenAddr(listenAddr) + if err != nil { + return + } + + if bcastIPs == nil { + bcastIPs = DefaultBroadcastIPs() + } + + b = &Broadcast{ + listenAddr: listenAddr, + bcastIPs: bcastIPs, + ListenDeadline: DefaultBroadcastListenDeadline, + SendDeadline: DefaultBroadcastSendDeadline, + } + + return +} + +// Listen for incoming Send requests from other Broadcast implementations +func (b *Broadcast) Listen(ctx context.Context, msgChan chan<- []byte) error { + b.lconnLock.Lock() + defer b.lconnLock.Unlock() + + var err error + b.lconn, err = net.ListenUDP("udp4", b.listenAddr) + if err != nil { + return err + } + defer b.lconn.Close() + + for { + select { + case <-ctx.Done(): + return nil + + default: + // noop + } + + packet := make([]byte, MaxUDPPacketSize) + b.lconn.SetReadDeadline(time.Now().Add(b.ListenDeadline)) + dataLen, _, err := b.lconn.ReadFrom(packet) + if err != nil { + // only return err if it was not a timeout error + if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return err + } + + continue + } + + msgChan <- packet[:dataLen] + } +} + +// Send a message to other potential Broadcast implementations +func (b *Broadcast) Send(msg []byte) error { + msgLen := len(msg) + if msgLen > MaxUDPPacketSize || msgLen <= 0 { + return ErrBroadcastSendBadSize + } + + conn, err := net.ListenUDP("udp4", nil) + if err != nil { + return err + } + defer conn.Close() + + for _, ip := range b.bcastIPs { + dst := &net.UDPAddr{IP: ip, Port: b.listenAddr.Port} + // be vigilant about short sends + conn.SetWriteDeadline(time.Now().Add(b.SendDeadline)) + _, err := conn.WriteTo(msg, dst) + conn.SetWriteDeadline(time.Time{}) + + if err != nil { + // only return err if it wasn't a timeout error + if err, ok := err.(net.Error); !ok || !err.Timeout() { + return err + } + } + } + + return nil +} + +// GetBroadcastIPs from the given list of addresses +func GetBroadcastIPs(addrs []net.Addr) (bcastIPs []net.IP) { + // IPv4 is enforced for broadcast, this will store the raw IPv4 address's value + bCastBytes := [4]byte{} + + for _, addr := range addrs { + ipAddr, ok := addr.(*net.IPNet) + + if ok && ipAddr.IP.IsGlobalUnicast() && ipAddr.IP.To4() != nil { + // per https://en.wikipedia.org/wiki/Broadcast_address#IP_networking + // The broadcast address for an IPv4 host can be obtained by performing + // a bitwise OR operation between the bit complement of the subnet mask + // and the host's IP address. In other words, take the host's IP address, + // and set to '1' any bit positions which hold a '0' in the subnet mask. + // For broadcasting a packet to an entire IPv4 subnet using the private IP + // address space 172.16.0.0/12, which has the subnet mask 255.240.0.0, the + // broadcast address is 172.16.0.0 | 0.15.255.255 = 172.31.255.255. + mask := ipAddr.Mask + mask = append(make([]byte, len(mask)-4), mask...) + + for i, num := range ipAddr.IP.To4() { + bCastBytes[i] = num | (^mask[i]) + } + + bcastIPs = append(bcastIPs, net.IPv4(bCastBytes[0], bCastBytes[1], bCastBytes[2], bCastBytes[3])) + } + } + + return +} + +// DefaultBroadcastIPs gets all available interface IPs for doing broadcasting +func DefaultBroadcastIPs() (bcastIPs []net.IP) { + addrs, _ := net.InterfaceAddrs() + bcastIPs = GetBroadcastIPs(addrs) + if bcastIPs == nil { + bcastIPs = append(bcastIPs, net.IPv4bcast) + } + + return +} + +func validateBcastListenAddr(listenAddr *net.UDPAddr) error { + if listenAddr == nil || listenAddr.IP.To4() == nil { + return ErrNotIPv4 + } + + if listenAddr.Port <= 0 { + return ErrZeroPort + } + + return nil +} diff --git a/local/broadcast_test.go b/local/broadcast_test.go new file mode 100644 index 0000000..9882a77 --- /dev/null +++ b/local/broadcast_test.go @@ -0,0 +1,266 @@ +package broadcast + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewBroadcastBadListenAddr(t *testing.T) { + assert := assert.New(t) + + badAddrs := []*net.UDPAddr{ + nil, + &net.UDPAddr{IP: net.IPv6loopback}, + &net.UDPAddr{IP: net.IPv4zero}, + } + + expectedErrs := []error{ + ErrNotIPv4, + ErrNotIPv4, + ErrZeroPort, + } + + for i, badAddr := range badAddrs { + b, err := NewBroadcast(badAddr, nil) + assert.Nil(b) + assert.Equal(expectedErrs[i], err) + } +} + +func TestNewBroadcastBroadcastIPSet(t *testing.T) { + assert := assert.New(t) + + inputs := [][]net.IP{ + nil, + DefaultBroadcastIPs(), + []net.IP{ + net.IPv4bcast, + }, + []net.IP{ + net.IPv4bcast, + net.IPv4allsys, + net.IPv4allrouter, + }, + } + + expectedOutputs := [][]net.IP{ + DefaultBroadcastIPs(), + inputs[1], + inputs[2], + inputs[3], + } + + laddr := &net.UDPAddr{IP: net.IPv4zero} + laddr.Port = 123 + + for i, input := range inputs { + b, err := NewBroadcast(laddr, input) + assert.NoError(err) + assert.NotNil(b) + + assert.Equal(expectedOutputs[i], b.bcastIPs) + } +} + +func TestBroadcastListenBadAddress(t *testing.T) { + assert := assert.New(t) + + b, err := NewBroadcast(&net.UDPAddr{IP: net.IPv4zero, Port: -1}, nil) + assert.Error(err) + assert.Nil(b) +} + +func TestBroadcastListenContextReturn(t *testing.T) { + assert := assert.New(t) + + laddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + b, err := NewBroadcast(laddr, nil) + assert.NoError(err) + assert.NotNil(b) + b.listenAddr.Port = 0 + b.ListenDeadline = time.Nanosecond + + ctx, cancel := context.WithCancel(context.Background()) + + errChan := make(chan error) + go func() { + errChan <- b.Listen(ctx, nil) + }() + + cancel() + assert.NoError(<-errChan) +} + +func TestBroadcastListenNonTimeoutError(t *testing.T) { + assert := assert.New(t) + + laddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + b, err := NewBroadcast(laddr, nil) + assert.NoError(err) + assert.NotNil(b) + b.listenAddr.Port = 0 + b.ListenDeadline = time.Nanosecond + + errChan := make(chan error) + go func() { + errChan <- b.Listen(context.Background(), nil) + }() + + // wait for the listening connectino to be created + for b.lconn == nil { + time.Sleep(time.Millisecond) + } + + b.lconn.Close() + + assert.Error(<-errChan) +} + +func TestBroadcastListenTimeoutErrorThenDone(t *testing.T) { + assert := assert.New(t) + + laddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + b, err := NewBroadcast(laddr, nil) + assert.NoError(err) + assert.NotNil(b) + b.listenAddr.Port = 0 + b.ListenDeadline = time.Nanosecond + + ctx, cancel := context.WithCancel(context.Background()) + + errChan := make(chan error) + go func() { + errChan <- b.Listen(ctx, nil) + }() + + time.Sleep(time.Millisecond * time.Duration(10)) + cancel() + + assert.NoError(<-errChan) +} + +func TestBroadcastListen(t *testing.T) { + assert := assert.New(t) + + laddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + b, err := NewBroadcast(laddr, nil) + assert.NoError(err) + assert.NotNil(b) + b.listenAddr.Port = 0 + b.ListenDeadline = time.Millisecond * time.Duration(10) + ctx, cancel := context.WithCancel(context.Background()) + + errChan := make(chan error) + msgChan := make(chan []byte) + go func() { + errChan <- b.Listen(ctx, msgChan) + }() + + // wait for the listening connectino to be created + for b.lconn == nil { + time.Sleep(time.Millisecond) + } + + conn, err := net.ListenUDP("udp4", nil) + assert.NoError(err) + defer conn.Close() + + expectedBytes := []byte("Hello there! GENERAL KENOBI!") + + _, err = conn.WriteTo(expectedBytes, b.lconn.LocalAddr()) + assert.NoError(err) + + gotBytes := <-msgChan + cancel() + assert.NoError(<-errChan) + assert.Equal(expectedBytes, gotBytes) +} + +func TestBroadcastSendBadmsgSize(t *testing.T) { + assert := assert.New(t) + + laddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + b, err := NewBroadcast(laddr, nil) + assert.NoError(err) + assert.NotNil(b) + + badMsgs := [][]byte{ + nil, + make([]byte, 0), + []byte{}, + make([]byte, MaxUDPPacketSize+1), + } + + for _, badMsg := range badMsgs { + assert.Equal(ErrBroadcastSendBadSize, b.Send(badMsg)) + } +} + +func TestBroadcastSend(t *testing.T) { + assert := assert.New(t) + + laddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + b, err := NewBroadcast(laddr, nil) + assert.NoError(err) + assert.NotNil(b) + + msgChan := make(chan []byte) + errChan := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + errChan <- b.Listen(ctx, msgChan) + }() + + // wait for the listening connectino to be created + for b.lconn == nil { + time.Sleep(time.Millisecond) + } + + bClient, err := NewBroadcast((b.lconn.LocalAddr()).(*net.UDPAddr), []net.IP{(b.lconn.LocalAddr()).(*net.UDPAddr).IP}) + assert.NoError(err) + + expectedBytes := []byte("hello there, GENERAL KENOBI!!!!") + assert.NoError(bClient.Send(expectedBytes)) + + assert.Equal(expectedBytes, <-msgChan) + cancel() +} + +func TestGetBroadcastIPs(t *testing.T) { + assert := assert.New(t) + + cidrToAddr := func(cidr string) net.Addr { + _, net, err := net.ParseCIDR(cidr) + assert.NoError(err) + + return net + } + + inputAddrs := []net.Addr{ + cidrToAddr("192.168.1.0/24"), + cidrToAddr("192.168.0.0/16"), + cidrToAddr("192.0.0.0/8"), + cidrToAddr("192.168.1.0/4"), + // the following should not generate results + cidrToAddr("127.0.0.1/24"), + &net.IPAddr{IP: net.IPv6linklocalallnodes}, + &net.IPAddr{IP: net.IPv4bcast}, + } + + expectedAddrs := []net.IP{ + net.IPv4(192, 168, 1, 255), + net.IPv4(192, 168, 255, 255), + net.IPv4(192, 255, 255, 255), + net.IPv4(207, 255, 255, 255), + } + + gotAddrs := GetBroadcastIPs(inputAddrs) + assert.Len(gotAddrs, len(expectedAddrs)) + for i, gotAddr := range gotAddrs { + assert.Equal(expectedAddrs[i], gotAddr) + } +} diff --git a/seeker.go b/seeker.go new file mode 100644 index 0000000..245a01d --- /dev/null +++ b/seeker.go @@ -0,0 +1,9 @@ +package discovery + +import "context" + +// Seeker defines a way to send/recv messages to potential peers +type Seeker interface { + Listen(ctx context.Context, msgChan chan<- []byte) error + Send(msg []byte) error +}