package main
// Lots of learning right out of the gate:
// https://stackoverflow.com/questions/51472020/how-to-get-the-size-of-available-tcp-data
import (
  "bufio"
  "crypto/rand"
  "encoding/base64"
  "flag"
  "fmt"
  "io"
  "io/ioutil"
  "net"
  "net/http"
  "net/url"
  "os"
  "strconv"
  "strings"
  "sync"
  "time"
  "gopkg.in/yaml.v2"
)
type (
	// ConfMailer holds the settings sendAuthCode uses to email auth codes.
	ConfMailer struct {
		Url    string `yaml:"url,omitempty"`     // endpoint the mail form is POSTed to
		ApiKey string `yaml:"api_key,omitempty"` // basic-auth password (user "api")
		From   string `yaml:"from,omitempty"`    // sender address for auth-code mail
	}

	// Conf is the top-level server configuration loaded from YAML.
	Conf struct {
		// Port is the TCP listen port, used when -telnet-port is not given.
		Port uint `yaml:"port,omitempty"`
		// Mailer has no tag, so yaml.v2 reads it from the lower-cased
		// field name, "mailer".
		Mailer ConfMailer
	}
)
// bufferedConn wraps a net.Conn with a bufio.Reader so callers can Peek at
// incoming bytes without consuming them; subsequent Reads still return them.
// Every other net.Conn method (Write, Close, RemoteAddr, ...) is promoted
// from the embedded connection.
//
// rout, when non-nil, replaces the bufio.Reader as the source for Read.
// NOTE(review): nothing visible here sets rout (newBufferedConn leaves it
// nil) — confirm whether it is still needed.
type bufferedConn struct {
	r    *bufio.Reader
	rout io.Reader
	net.Conn
}

// newBufferedConn wraps c with a default-sized bufio.Reader.
func newBufferedConn(c net.Conn) bufferedConn {
	return bufferedConn{
		r:    bufio.NewReader(c),
		rout: nil,
		Conn: c,
	}
}

// Peek returns the next n bytes without advancing the reader. Per bufio,
// the returned slice is only valid until the next read.
func (bc bufferedConn) Peek(n int) ([]byte, error) {
	peeked, err := bc.r.Peek(n)
	return peeked, err
}

// Buffered reports how many bytes can be read without touching the
// underlying connection.
func (bc bufferedConn) Buffered() int {
	return bc.r.Buffered()
}

// Read reads from rout when it is set, otherwise through the buffer, so any
// bytes previously seen via Peek are returned first.
func (bc bufferedConn) Read(p []byte) (int, error) {
	var src io.Reader = bc.r
	if bc.rout != nil {
		src = bc.rout
	}
	return src.Read(p)
}
// myMsg is one chunk of data read from a client, queued for main's event
// loop.
type myMsg struct {
	sender     net.Conn  // connection the bytes were read from
	bytes      []byte    // the data as read from the socket
	receivedAt time.Time // when the handler queued the message
	channel    string    // chat channel name; only "general" is used so far
}

// Shared server state. All of these are created in main.
var (
	// firstMsgs carries the first packet of connections handled by
	// handleSorted; main only logs these.
	firstMsgs chan myMsg
	// myChans maps a chat channel name to its message queue.
	myChans map[string]chan myMsg
	// myUnsortedConns records accepted connections.
	myUnsortedConns map[net.Conn]bool
	// myRawConns is the set of telnet-style connections main broadcasts to.
	myRawConns map[net.Conn]bool
	// newConns carries accepted sockets from the accept loop to main.
	newConns chan net.Conn
)
// usage prints a short usage line and the flag defaults, then exits with
// status 1. main installs it as flag.Usage.
func usage() {
	stderr := os.Stderr
	fmt.Fprint(stderr, "\nusage: go run chatserver.go\n")
	flag.PrintDefaults()
	fmt.Fprintln(os.Stdout)
	os.Exit(1)
}
// genAuthCode returns a random, URL-safe authentication code: 12 bytes from
// crypto/rand encoded with base64.URLEncoding, which gives 16 characters
// with no padding. The error is non-nil only if reading the secure random
// source fails.
//
// See https://blog.questionable.services/article/generating-secure-random-numbers-crypto-rand/
func genAuthCode() (string, error) {
	const numBytes = 12
	raw := make([]byte, numBytes)
	if _, err := rand.Read(raw); err != nil {
		// rand.Read returns a nil error only when it filled all of raw.
		return "", err
	}
	return base64.URLEncoding.EncodeToString(raw), nil
}
// handleRaw serves a "raw" (telnet-style) client: one that stayed silent
// until the server's welcome banner (see handleConnection).
//
// Until the client is authenticated, each read is treated as one line of
// input. The first valid line is an email address, to which sendAuthCode
// mails a code. The next lines must match that code. After that, every read
// is queued on myChans["general"] for main to broadcast.
//
// When the client disconnects or a read fails, the connection is closed and
// removed from myRawConns.
func handleRaw(conn bufferedConn) {
	// TODO
	// What happens if this is being read from range
	// when it's being added here (data race)?
	// Should I use a channel here instead?
	// TODO see https://jameshfisher.com/2017/04/18/golang-tcp-server.html
	defer func() {
		// TODO put this in a channel to prevent data races
		conn.Close()
		delete(myRawConns, conn)
	}()
	var email string
	var code string
	var authn bool
	// Handle all subsequent packets
	buffer := make([]byte, 1024)
	for {
		fmt.Fprintf(os.Stdout, "[raw] Waiting for message...\n")
		count, err := conn.Read(buffer)
		if nil != err {
			if io.EOF != err {
				fmt.Fprintf(os.Stderr, "Non-EOF socket error: %s\n", err)
			}
			fmt.Fprintf(os.Stdout, "Ending socket\n")
			return
		}
		// io.Reader implementations are discouraged from returning 0, nil
		// for a non-empty buffer. Treat it as a broken connection rather
		// than spin. The deferred cleanup still closes and unregisters it.
		if 0 == count {
			fmt.Fprintf(os.Stdout, "Weird")
			return
		}
		buf := buffer[:count]
		line := strings.TrimSpace(string(buf))
		if !authn {
			if "" == email {
				fmt.Fprintf(os.Stdout, "buf{%s}\n", buf)
				// TODO use safer email testing
				if 2 != len(strings.Split(line, "@")) {
					// Leave email empty so the next line is another email
					// attempt. Keeping an invalid address would jump to the
					// code check while code is still "", and then a blank
					// line would authenticate.
					fmt.Fprintf(conn, "Email: ")
					continue
				}
				email = line
				fmt.Fprintf(os.Stdout, "email: '%v'\n", []byte(email))
				code, err = sendAuthCode(config.Mailer, email)
				if nil != err {
					// A bad address or a mailer outage must not panic the
					// whole server; log it and let this user retry.
					fmt.Fprintf(os.Stderr, "Couldn't send auth code to %q: %s\n", email, err)
					email = ""
					fmt.Fprintf(conn, "Couldn't send auth code\nEmail: ")
					continue
				}
				fmt.Fprintf(conn, "Auth Code: ")
				continue
			}
			// Never accept the empty string as a code (defense in depth:
			// code is only empty if sending it failed).
			if "" == code || code != line {
				fmt.Fprintf(conn, "Incorrect Code\nAuth Code: ")
			} else {
				authn = true
				fmt.Fprintf(conn, "Welcome to #general! (TODO `/help' for list of commands)\n")
				// TODO number of users
				//fmt.Fprintf(conn, "Welcome to #general! TODO `/list' to see channels. `/join chname' to switch.\n")
			}
			continue
		}
		fmt.Fprintf(os.Stdout, "Queing message...\n")
		myChans["general"] <- myMsg{
			receivedAt: time.Now(),
			sender:     conn,
			// Copy: buffer is reused by the next Read, and main may not
			// have dequeued this message yet.
			bytes:   append([]byte(nil), buf...),
			channel: "general",
		}
	}
}
// handleSorted serves a connection whose client sent data before the
// server's welcome banner (see handleConnection). Per the TODO below, no
// protocol (HTTP, TLS, ...) is parsed yet.
//
// The bytes already buffered by the initial Peek are reported on firstMsgs.
// Every read is then queued on myChans["general"]. Peek does not consume,
// so the first read returns those same bytes again. The connection is
// closed when a read fails, including io.EOF.
func handleSorted(conn bufferedConn) {
	defer conn.Close()
	// At this point handleConnection's Peek(1) has succeeded, so at least
	// one byte of the first packet is sitting in the buffer.
	n := conn.Buffered()
	firstMsg, err := conn.Peek(n)
	if nil != err {
		// Peek(n) with n == Buffered() should not fail; treat it as a bug.
		panic(err)
	}
	firstMsgs <- myMsg{
		receivedAt: time.Now(),
		sender:     conn,
		// Copy: Peek's slice aliases the bufio buffer, which the reads
		// below may overwrite before main dequeues this message.
		bytes:   append([]byte(nil), firstMsg...),
		channel: "general",
	}
	// TODO
	// * TCP-CHAT
	// * HTTP
	// * TLS
	// Handle all subsequent packets
	buf := make([]byte, 1024)
	for {
		fmt.Fprintf(os.Stdout, "[sortable] Waiting for message...\n")
		count, err := conn.Read(buf)
		if nil != err {
			if io.EOF != err {
				fmt.Fprintf(os.Stderr, "Non-EOF socket error: %s\n", err)
			}
			fmt.Fprintf(os.Stdout, "Ending socket\n")
			return
		}
		// Fun fact: if the buffer's current length (not capacity) is 0
		// then the Read returns 0 without error
		if 0 == count {
			continue
		}
		myChans["general"] <- myMsg{
			receivedAt: time.Now(),
			sender:     conn,
			// Copy: buf is reused by the next Read.
			bytes:   append([]byte(nil), buf[:count]...),
			channel: "general",
		}
	}
}
// unsortedConnsMu guards myUnsortedConns. main starts handleConnection in
// its own goroutine for every accepted socket, and concurrent writes to a
// Go map are a fatal runtime error.
var unsortedConnsMu sync.Mutex

// TODO https://github.com/polvi/sni
//
// handleConnection sorts a freshly accepted socket by which side speaks
// first.
//
// A goroutine waits for the client's first byte.
//   - If it arrives within the 250ms grace period, the connection goes to
//     handleSorted.
//   - Otherwise the client is assumed to be a telnet user waiting for a
//     banner. The welcome prompt is written, and once the client sends data
//     the connection is added to myRawConns and handed to handleRaw.
//   - A client that disconnects before sending anything is closed and
//     dropped.
func handleConnection(netConn net.Conn) {
	fmt.Fprintf(os.Stdout, "Accepting socket\n")
	// m guards virgin, which this function and the first-byte goroutine
	// race to claim.
	m := sync.Mutex{}
	virgin := true
	// Peek through a bufio.Reader instead of Read-ing into a zero-length
	// slice (make([]byte, 0, 1024) or []byte{}). Read fills up to len(p),
	// not cap(p), so an empty slice returns 0 bytes immediately instead of
	// waiting for data. Peeking also leaves the bytes in place for whichever
	// handler takes over.
	bufConn := newBufferedConn(netConn)
	unsortedConnsMu.Lock()
	myUnsortedConns[bufConn] = true
	unsortedConnsMu.Unlock()
	go func() {
		// Block until the client's first byte arrives or the socket dies.
		fmsg, err := bufConn.Peek(1)
		// Sorted or dead, the connection is no longer unsorted.
		unsortedConnsMu.Lock()
		delete(myUnsortedConns, bufConn)
		unsortedConnsMu.Unlock()
		if nil != err {
			// The client left without sending anything, e.g. a port scanner
			// or a user quitting at the prompt. Drop just this connection;
			// panicking here would kill the whole server.
			if io.EOF != err {
				fmt.Fprintf(os.Stderr, "Non-EOF socket error: %s\n", err)
			}
			fmt.Fprintf(os.Stdout, "Ending socket\n")
			m.Lock()
			virgin = false // don't send the banner to a dead socket
			m.Unlock()
			bufConn.Close()
			return
		}
		fmt.Fprintf(os.Stdout, "[First Byte] %s\n", fmsg)
		m.Lock()
		if virgin {
			// The client spoke before the grace period ended.
			virgin = false
			go handleSorted(bufConn)
		} else {
			// The banner was already sent, so this is a telnet-style client.
			// TODO probably needs to go into a channel: main ranges over
			// myRawConns concurrently with this write.
			myRawConns[bufConn] = true
			go handleRaw(bufConn)
		}
		m.Unlock()
	}()
	// Grace period for clients that speak first.
	time.Sleep(250 * time.Millisecond)
	// If we still haven't received data from the client
	// assume that the client must be expecting a welcome from us
	m.Lock()
	if virgin {
		virgin = false
		// don't block for this
		// let it be handled after the unlock
		defer fmt.Fprintf(netConn, "Welcome to Sample Chat! You appear to be using Telnet.\nYou must authenticate via email to participate\nEmail: ")
	}
	m.Unlock()
}
func sendAuthCode(cnf ConfMailer, to string) (string, error) {
  code, err := genAuthCode()
  if nil != err {
    return "", err
  }
  // TODO use go text templates with HTML escaping
  text := "Your authorization code:\n\n" + code
  html := "Your authorization code:
" + code
  // https://stackoverflow.com/questions/24493116/how-to-send-a-post-request-in-go
  // https://stackoverflow.com/questions/16673766/basic-http-auth-in-go
	client := http.Client{}
	form := url.Values{}
	form.Add("from", cnf.From)
	form.Add("to", to)
	form.Add("subject", "Sample Chat Auth Code: " + code)
	form.Add("text", text)
	form.Add("html", html)
	req, err := http.NewRequest("POST", cnf.Url, strings.NewReader(form.Encode()))
  if nil != err {
    return "", err
  }
	//req.PostForm = form
	req.Header.Add("User-Agent", "golang http.Client - Sample Chat App Authenticator")
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth("api", cnf.ApiKey)
  resp, err := client.Do(req)
  if nil != err {
    return "", err
  }
  defer resp.Body.Close()
  body, err := ioutil.ReadAll(resp.Body)
  if nil != err {
    return "", err
  }
  fmt.Fprintf(os.Stdout, "Here's what Mailgun had to say about the event: %s\n", body)
  return code, nil
}
// config is the server configuration. main loads it from the -conf YAML
// file; handleRaw reads config.Mailer from it.
var config Conf

// main parses flags, loads the configuration, and binds the TCP listener.
// It then runs the server's single event loop, which:
//   - starts handleConnection for each accepted socket,
//   - logs each "general" message and writes it to every other connection
//     in myRawConns,
//   - logs the first packet of sorted connections.
//
// The listen port comes from -telnet-port if non-zero, else from config.
func main() {
	flag.Usage = usage
	port := flag.Uint("telnet-port", 0, "tcp telnet chat port")
	confname := flag.String("conf", "./config.yml", "yaml config file")
	flag.Parse()
	confstr, err := ioutil.ReadFile(*confname)
	fmt.Fprintf(os.Stdout, "-conf=%s\n", *confname)
	if nil != err {
		fmt.Fprintf(os.Stderr, "%s\nUsing defaults instead\n", err)
		// JSON is valid YAML, so this minimal document feeds the same parser.
		confstr = []byte("{\"port\":" + strconv.Itoa(int(*port)) + "}")
	}
	err = yaml.Unmarshal(confstr, &config)
	if nil != err {
		// Report the error instead of silently running unconfigured.
		fmt.Fprintf(os.Stderr, "Couldn't parse config %q: %s\nUsing defaults instead\n", *confname, err)
		config = Conf{}
	}
	firstMsgs = make(chan myMsg, 128)
	myChans = make(map[string](chan myMsg))
	newConns = make(chan net.Conn, 128)
	myRawConns = make(map[net.Conn]bool)
	myUnsortedConns = make(map[net.Conn]bool)
	// TODO dynamically select on channels?
	// https://stackoverflow.com/questions/19992334/how-to-listen-to-n-channels-dynamic-select-statement
	myChans["general"] = make(chan myMsg, 128)
	var addr string
	if 0 != *port {
		addr = ":" + strconv.Itoa(int(*port))
	} else {
		addr = ":" + strconv.Itoa(int(config.Port))
	}
	// https://golang.org/pkg/net/#Conn
	sock, err := net.Listen("tcp", addr)
	if nil != err {
		fmt.Fprintf(os.Stderr, "Couldn't bind to TCP socket %q: %s\n", addr, err)
		os.Exit(2)
	}
	fmt.Println("Listening on", addr)
	go func() {
		for {
			conn, err := sock.Accept()
			if err != nil {
				fmt.Fprintf(os.Stderr, "Error accepting connection:\n%s\n", err)
				// conn is nil here; forwarding it would crash
				// handleConnection. Pause briefly so a persistent error
				// (e.g. out of file descriptors) doesn't spin this loop.
				time.Sleep(100 * time.Millisecond)
				continue
			}
			newConns <- conn
		}
	}()
	for {
		select {
		case conn := <-newConns:
			ts := time.Now()
			fmt.Fprintf(os.Stdout, "[Handle New Connection] [Timestamp] %s\n", ts)
			go handleConnection(conn)
		case msg := <-myChans["general"]:
			ts, err := msg.receivedAt.MarshalJSON()
			if nil != err {
				fmt.Fprintf(os.Stderr, "[Error] %s\n", err)
			}
			fmt.Fprintf(os.Stdout, "[Timestamp] %s\n", ts)
			fmt.Fprintf(os.Stdout, "[Remote] %s\n", msg.sender.RemoteAddr().String())
			fmt.Fprintf(os.Stdout, "[Message] %s\n", msg.bytes)
			// NOTE(review): connections join myRawConns before they
			// authenticate, so clients still at the Email/Auth Code prompt
			// also receive these messages.
			for conn := range myRawConns {
				if msg.sender == conn {
					continue
				}
				// backlogged connections could prevent a next write,
				// so this should be refactored into a goroutine
				// And what to do about slow clients that get behind (or DoS)?
				// SetDeadTime and Disconnect them?
				conn.Write(msg.bytes)
			}
		case msg := <-firstMsgs:
			fmt.Fprintf(os.Stdout, "f [First Message]\n")
			ts, err := msg.receivedAt.MarshalJSON()
			if nil != err {
				fmt.Fprintf(os.Stderr, "f [Error] %s\n", err)
			}
			fmt.Fprintf(os.Stdout, "f [Timestamp] %s\n", ts)
			fmt.Fprintf(os.Stdout, "f [Remote] %s\n", msg.sender.RemoteAddr().String())
			fmt.Fprintf(os.Stdout, "f [Message] %s\n", msg.bytes)
		}
	}
}