From 2c9c1d865c36c7b8a0a99534be106b98a0231f85 Mon Sep 17 00:00:00 2001 From: Tony Blyler Date: Mon, 11 Jun 2018 17:14:22 -0400 Subject: [PATCH] Initial commit --- client/client.go | 89 ++++++++++++++++++ main.go | 207 ++++++++++++++++++++++++++++++++++++++++++ server/server.go | 169 ++++++++++++++++++++++++++++++++++ signer/cleartext.go | 16 ++++ signer/cryptsigner.go | 7 ++ signer/secretbox.go | 56 ++++++++++++ 6 files changed, 544 insertions(+) create mode 100644 client/client.go create mode 100644 main.go create mode 100644 server/server.go create mode 100644 signer/cleartext.go create mode 100644 signer/cryptsigner.go create mode 100644 signer/secretbox.go diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..c5a5a4d --- /dev/null +++ b/client/client.go @@ -0,0 +1,89 @@ +package client + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "net" + "sync" + "sync/atomic" +) + +type ClientID uint64 + +var currentClientID uint64 + +func getID() ClientID { + return ClientID(atomic.AddUint64(¤tClientID, 1)) +} + +type Client struct { + conn net.Conn + id ClientID + ctx context.Context + cancel context.CancelFunc + writeLock sync.Mutex + readLock sync.Mutex +} + +func NewClient(conn net.Conn, ctx context.Context) *Client { + ctx, cancel := context.WithCancel(ctx) + return &Client{ + conn: conn, + id: getID(), + ctx: ctx, + cancel: cancel, + } +} + +func (c *Client) GetID() ClientID { + return c.id +} + +func (c *Client) SendMsg(msg []byte) error { + c.writeLock.Lock() + defer c.writeLock.Unlock() + + sizeBytes := make([]byte, 8) + binary.PutVarint(sizeBytes, int64(len(msg))) + + _, err := io.Copy(c.conn, bytes.NewReader(append(sizeBytes, msg...))) + return err +} + +func (c *Client) GetMsg() ([]byte, error) { + c.readLock.Lock() + defer c.readLock.Unlock() + + sizeBytes := make([]byte, 8) + _, err := io.ReadFull(c.conn, sizeBytes) + if err != nil { + return nil, err + } + + msgSize, errInt := binary.Varint(sizeBytes) + if errInt <= 0 || msgSize <= 0 { + return nil, errors.New("Failed to decode message size bytes") + } + + output := bytes.NewBuffer(make([]byte, msgSize)) + output.Reset() + + _, err = io.CopyN(output, c.conn, msgSize) + if err != nil { + return nil, err + } + + return output.Bytes(), nil +} + +func (c *Client) Done() bool { + return c.ctx.Err() != nil +} + +func (c *Client) Close() error { + c.cancel() + return c.conn.Close() +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..b07345b --- /dev/null +++ b/main.go @@ -0,0 +1,207 @@ +package main + +import ( + "bytes" + "context" + "encoding/base64" + "flag" + "fmt" + "io" + "net" + "os" + + "github.com/jroimartin/gocui" + "gitlab.com/tblyler/ep2pchat/client" + "gitlab.com/tblyler/ep2pchat/server" + "gitlab.com/tblyler/ep2pchat/signer" +) + +func logErrConn(conn net.Conn, msgs ...interface{}) { + fmt.Fprint(os.Stderr, "["+conn.RemoteAddr().String()+"] ") + fmt.Fprintln(os.Stderr, msgs...) +} + +func logConn(conn net.Conn, msgs ...interface{}) { + fmt.Print("[" + conn.RemoteAddr().String() + "] ") + fmt.Println(msgs...) +} + +func logErr(msgs ...interface{}) { + fmt.Fprintln(os.Stderr, msgs...) +} + +func log(msgs ...interface{}) { + fmt.Println(msgs...) +} + +func main() { + keyStr := "" + host := "0.0.0.0:0" + isServer := false + + flag.StringVar(&keyStr, "key", keyStr, "256-bit binary base64-encoded key value, generates otherwise") + flag.StringVar(&host, "host", host, "the host to connect to or listen on") + flag.BoolVar(&isServer, "server", isServer, "determines whether or not to act as a server") + flag.Parse() + + err := func() (err error) { + var sign signer.Signer + if keyStr == "" { + log("not using encryption nor signing") + sign = signer.NewClearText() + } else { + key, err := base64.StdEncoding.DecodeString(keyStr) + if err != nil || len(key) != signer.KeyLength { + logErr("Failed to decode 256-bit key from base64") + return err + } + + var keyArray [signer.KeyLength]byte + copy(keyArray[:], key) + + sign = signer.NewSecretBox(keyArray) + } + + if isServer { + listener, err := net.Listen("tcp", host) + if err != nil { + logErr("Failed to listen on address ", host) + return err + } + + defer listener.Close() + + log("Listening on ", listener.Addr().String()) + + server := server.NewServer(listener, sign, context.Background()) + return server.Serve() + } + + conn, err := net.Dial("tcp", host) + if err != nil { + logErr("Failed to connect to address ", host) + return err + } + + defer conn.Close() + + client := client.NewClient(conn, context.Background()) + + gui, err := gocui.NewGui(gocui.OutputNormal) + if err != nil { + logErr("Failed to initialze UI") + return err + } + defer gui.Close() + + gui.Cursor = true + gui.SetManagerFunc(func(gui *gocui.Gui) error { + maxX, maxY := gui.Size() + + v, err := gui.SetView("chat", 1, 1, maxX-1, maxY-5) + if err != nil { + if err != gocui.ErrUnknownView { + return err + } + + if _, err := gui.SetCurrentView("chat"); err != nil { + return err + } + + v.Editable = true + v.Wrap = true + } + v, err = gui.SetView("input", 1, maxY-4, maxX-1, maxY-1) + if err != nil { + if err != gocui.ErrUnknownView { + return err + } + + if _, err := gui.SetCurrentView("input"); err != nil { + return err + } + + v.Editable = true + v.Wrap = true + } + + return nil + }) + gui.SetKeybinding("", gocui.KeyCtrlC, gocui.ModNone, func(g *gocui.Gui, v *gocui.View) error { + return gocui.ErrQuit + }) + gui.SetKeybinding("input", gocui.KeyEnter, gocui.ModNone, func(g *gocui.Gui, v *gocui.View) error { + msg := v.ViewBuffer() + + if msg == "" { + return nil + } + + encMsg, err := sign.Encode([]byte(msg)) + if err != nil { + logErr(err) + return gocui.ErrQuit + } + + err = client.SendMsg(encMsg) + if err != nil { + logErr(err) + return gocui.ErrQuit + } + + v.SetCursor(0, 0) + v.Clear() + + return nil + }) + + go func() { + for { + encMsg, err := client.GetMsg() + if err != nil { + logErr(err) + break + } + + msg, err := sign.Decode(encMsg) + if err != nil { + logErr(err) + break + } + + gui.Update(func(gui *gocui.Gui) error { + v, err := gui.View("chat") + if err != nil { + gui.Close() + return err + } + + _, err = io.Copy(v, bytes.NewReader(msg)) + if err != nil { + logErr(err) + return err + } + + return nil + }) + if err != nil { + break + } + } + + gui.Close() + }() + + err = gui.MainLoop() + if err != nil && err != gocui.ErrQuit { + return err + } + + return nil + }() + + if err != nil { + logErr(err) + os.Exit(1) + } +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..040569b --- /dev/null +++ b/server/server.go @@ -0,0 +1,169 @@ +package server + +import ( + "context" + "fmt" + "io/ioutil" + "net" + "sync" + + "gitlab.com/tblyler/ep2pchat/client" + "gitlab.com/tblyler/ep2pchat/signer" +) + +type Server struct { + listener net.Listener + signer signer.Signer + ctx context.Context + + clients map[client.ClientID]*client.Client + clientsLock sync.Mutex + msgChan chan []byte +} + +func NewServer( + listener net.Listener, + signer signer.Signer, + ctx context.Context, +) *Server { + return (&Server{ + clients: make(map[client.ClientID]*client.Client), + }).SetListener( + listener, + ).SetSigner( + signer, + ).SetContext( + ctx, + ) +} + +func (s *Server) SetListener(listener net.Listener) *Server { + if listener == nil { + file, _ := ioutil.TempFile("", "") + listener, _ = net.FileListener(file) + } + + s.listener = listener + return s +} + +func (s *Server) SetSigner(sign signer.Signer) *Server { + if sign == nil { + sign = signer.NewClearText() + } + + s.signer = sign + return s +} + +func (s *Server) SetContext(ctx context.Context) *Server { + if ctx == nil { + ctx = context.Background() + } + + s.ctx = ctx + return s +} + +func (s *Server) Serve() error { + for { + conn, err := s.listener.Accept() + if err != nil { + return err + } + + client := s.addClient(conn) + go func() { + for !client.Done() { + encMsg, err := client.GetMsg() + if err != nil { + break + } + + msg, err := s.signer.Decode(encMsg) + if err != nil { + break + } + + msg = append([]byte(fmt.Sprintf("%d: ", client.GetID())), msg...) + + s.BroadcastMsg(msg) + } + + s.clientsLock.Lock() + s.removeClient(client) + s.clientsLock.Unlock() + + s.BroadcastMsg([]byte(fmt.Sprintf("Client %d has disconnected\n", client.GetID()))) + }() + + s.BroadcastMsg([]byte(fmt.Sprintf("Client %d has connected\n", client.GetID()))) + } +} + +func (s *Server) Close() (errs []error) { + s.clientsLock.Lock() + defer s.clientsLock.Unlock() + + clientCount := len(s.clients) + errChan := make(chan error, clientCount) + for _, c := range s.clients { + go func(client *client.Client) { + errChan <- client.Close() + }(c) + } + + for i := 0; i < clientCount; i++ { + err := <-errChan + if err != nil { + errs = append(errs, err) + } + } + + return +} + +func (s *Server) BroadcastMsg(msg []byte) error { + encMsg, err := s.signer.Encode(msg) + if err != nil { + return err + } + + s.clientsLock.Lock() + defer s.clientsLock.Unlock() + + wg := sync.WaitGroup{} + for _, c := range s.clients { + wg.Add(1) + go func(client *client.Client) { + err := client.SendMsg(encMsg) + if err != nil { + s.removeClient(client) + wg.Done() + s.BroadcastMsg([]byte(fmt.Sprintf("Client %d has disconnected\n", client.GetID()))) + return + } + + wg.Done() + }(c) + } + + wg.Wait() + return nil +} + +func (s *Server) addClient(conn net.Conn) *client.Client { + s.clientsLock.Lock() + + client := client.NewClient(conn, s.ctx) + s.clients[client.GetID()] = client + + s.clientsLock.Unlock() + + return client +} + +func (s *Server) removeClient(client *client.Client) { + delete(s.clients, client.GetID()) + client.Close() +} diff --git a/signer/cleartext.go b/signer/cleartext.go new file mode 100644 index 0000000..7dcfa39 --- /dev/null +++ b/signer/cleartext.go @@ -0,0 +1,16 @@ +package signer + +type ClearText struct { +} + +func NewClearText() *ClearText { + return new(ClearText) +} + +func (c *ClearText) Encode(rawMsg []byte) ([]byte, error) { + return rawMsg, nil +} + +func (c *ClearText) Decode(encMsg []byte) ([]byte, error) { + return encMsg, nil +} diff --git a/signer/cryptsigner.go b/signer/cryptsigner.go new file mode 100644 index 0000000..10458f5 --- /dev/null +++ b/signer/cryptsigner.go @@ -0,0 +1,7 @@ +package signer + +// Signer implements encoding and decoding methods for messages +type Signer interface { + Encode(rawMsg []byte) ([]byte, error) + Decode(encMsg []byte) ([]byte, error) +} diff --git a/signer/secretbox.go b/signer/secretbox.go new file mode 100644 index 0000000..0477816 --- /dev/null +++ b/signer/secretbox.go @@ -0,0 +1,56 @@ +package signer + +import ( + "crypto/rand" + "errors" + "io" + + "golang.org/x/crypto/nacl/secretbox" +) + +const ( + NonceLength = 24 + KeyLength = 32 +) + +type Secretbox struct { + key [KeyLength]byte +} + +func NewSecretBox(key [KeyLength]byte) *Secretbox { + return &Secretbox{ + key: key, + } +} + +func (s *Secretbox) Encode(rawMsg []byte) (encMsg []byte, err error) { + nonce, err := generateNonce() + if err != nil { + return + } + + encMsg = secretbox.Seal(nonce[:], rawMsg, &nonce, &s.key) + return +} + +func (s *Secretbox) Decode(encMsg []byte) (msg []byte, err error) { + if len(encMsg) < NonceLength { + err = errors.New("Invalid message length for decode") + return + } + + var nonce [NonceLength]byte + copy(nonce[:], encMsg[:NonceLength]) + + msg, ok := secretbox.Open(nil, encMsg[NonceLength:], &nonce, &s.key) + if !ok { + err = errors.New("Failed to decode message") + } + + return +} + +func generateNonce() (nonce [NonceLength]byte, err error) { + _, err = io.ReadFull(rand.Reader, nonce[:]) + return nonce, err +}