package internal

import (
	"context"
	"errors"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/go-kit/log"
	"github.com/go-kit/log/level"
	"tildegit.org/tjp/sliderule/logging"
)

// Server accepts connections on Listener and hands each one to HandleConn in
// its own goroutine (see Serve). It is built by NewServer and shut down with
// Close.
type Server struct {
	// Ctx is the server's lifetime context. Once it is done, Listener is
	// closed (see propagateClose) and Serve returns nil.
	Ctx         context.Context
	// Cancel cancels Ctx; Close calls it.
	Cancel      context.CancelFunc
	// Wg tracks Serve, the goroutine that closes Listener on shutdown, and
	// every per-connection handler goroutine. Close waits on it.
	Wg          *sync.WaitGroup
	// Listener is the source of incoming connections.
	Listener    net.Listener
	// HandleConn is called once per accepted connection; Serve closes the
	// connection after it returns.
	HandleConn  connHandler
	// ErrorLog receives accept errors from Serve and anything passed to
	// LogError.
	ErrorLog    logging.Logger
	// Host is the hostname passed to NewServer. Hostname and Port derive their
	// results from it; it is not used for listening.
	Host        string
	// NetworkAddr is the address the listener is actually bound to
	// (listener.Addr()); Network and Address report from it.
	NetworkAddr net.Addr
}

// connHandler processes a single accepted connection. Serve closes the
// connection after the handler returns.
type connHandler func(net.Conn)

// NewServer starts listening on address using the given network and returns a
// Server built around that listener. Serve must be called to begin accepting
// connections.
//
// hostname is recorded as Server.Host; it is not used for listening.
// NetworkAddr is taken from the listener, so it reflects the address actually
// bound. The server's Ctx is a cancelable child of ctx that carries leveled
// loggers derived from baseLog under the keys "debuglog", "infolog",
// "warnlog" and "errlog". A nil baseLog is replaced by a no-op logger.
// handleConn becomes Server.HandleConn.
//
// If net.Listen fails, the zero Server and that error are returned.
func NewServer(
	ctx context.Context,
	hostname string,
	network string,
	address string,
	baseLog logging.Logger,
	handleConn connHandler,
) (Server, error) {
	ln, err := net.Listen(network, address)
	if err != nil {
		return Server{}, err
	}

	if baseLog == nil {
		baseLog = log.NewNopLogger()
	}
	errLogger := level.Error(baseLog)

	// NOTE(review): built-in string keys trip staticcheck SA1029. Code
	// elsewhere presumably reads these values back by these exact strings, so
	// changing the key type would break it — verify before changing.
	srvCtx, cancel := context.WithCancel(ctx)
	leveled := []struct {
		key    string
		logger any
	}{
		{"debuglog", level.Debug(baseLog)},
		{"infolog", level.Info(baseLog)},
		{"warnlog", level.Warn(baseLog)},
		{"errlog", errLogger},
	}
	for _, entry := range leveled {
		srvCtx = context.WithValue(srvCtx, entry.key, entry.logger)
	}

	return Server{
		Ctx:         srvCtx,
		Cancel:      cancel,
		Wg:          new(sync.WaitGroup),
		Listener:    ln,
		HandleConn:  handleConn,
		ErrorLog:    errLogger,
		Host:        hostname,
		NetworkAddr: ln.Addr(),
	}, nil
}

// Serve accepts connections on s.Listener and runs s.HandleConn on each one in
// its own goroutine, closing the connection once the handler returns.
//
// Serve itself, its handler goroutines, and the goroutine that closes the
// listener on shutdown are all tracked by s.Wg, so Close can wait for them.
//
// Serve returns nil once the server has been closed (s.Ctx is done). Transient
// accept failures are logged and retried with an exponential backoff
// (5ms, doubling, capped at 1s), the same policy net/http.Server.Serve uses,
// rather than stopping the server. A transient failure is a net.Error whose
// Temporary() reports true, for example running out of file descriptors. Any
// other accept error is logged and returned.
func (s *Server) Serve() error {
	s.Wg.Add(1)
	defer s.Wg.Done()

	s.propagateClose()

	var retryDelay time.Duration
	for {
		conn, err := s.Listener.Accept()
		if err != nil {
			if s.Closed() {
				return nil
			}

			var netErr net.Error
			//lint:ignore SA1019 Temporary is deprecated but remains the signal net/http uses for retryable Accept errors.
			if errors.As(err, &netErr) && netErr.Temporary() {
				if retryDelay == 0 {
					retryDelay = 5 * time.Millisecond
				} else {
					retryDelay *= 2
				}
				if retryDelay > time.Second {
					retryDelay = time.Second
				}
				_ = s.ErrorLog.Log("msg", "accept error; retrying", "error", err, "delay", retryDelay)

				// Sleep, but wake immediately if the server is shut down meanwhile.
				timer := time.NewTimer(retryDelay)
				select {
				case <-s.Ctx.Done():
					timer.Stop()
					return nil
				case <-timer.C:
				}
				continue
			}

			_ = s.ErrorLog.Log("msg", "accept error", "error", err)
			return err
		}
		retryDelay = 0

		s.Wg.Add(1)
		go func() {
			defer s.Wg.Done()
			defer func() { _ = conn.Close() }()
			s.HandleConn(conn)
		}()
	}
}

// Hostname returns the host portion of Host.
//
// Host is normally "host:port", with IPv6 hosts bracketed. If Host carries no
// port, the whole value is treated as the host, minus any enclosing IPv6
// brackets. Returning an empty string there (net.SplitHostPort's result on
// failure) would silently lose the configured hostname.
func (s *Server) Hostname() string {
	host, _, err := net.SplitHostPort(s.Host)
	if err == nil {
		return host
	}

	h := s.Host
	if len(h) >= 2 && h[0] == '[' && h[len(h)-1] == ']' {
		return h[1 : len(h)-1]
	}
	return h
}

// Port returns the port portion of Host, or "" if Host cannot be split as
// "host:port" by net.SplitHostPort.
func (s *Server) Port() string {
	if _, port, err := net.SplitHostPort(s.Host); err == nil {
		return port
	}
	return ""
}

// Network reports the network name of the bound listener address (for example
// "tcp"), as given by NetworkAddr.Network.
func (s *Server) Network() string {
	addr := s.NetworkAddr
	return addr.Network()
}

// Address returns the string form of the address the listener is bound to, as
// reported by NetworkAddr.String. It reflects the actual bound address rather
// than the configured Host.
func (s *Server) Address() string {
	addr := s.NetworkAddr
	return addr.String()
}

// Close shuts the server down. It cancels Ctx, which causes the listener to
// be closed and Serve to return. It then blocks until every goroutine tracked
// by Wg has finished, including in-flight connection handlers. Close does not
// itself close active connections.
func (s *Server) Close() {
	cancel, wg := s.Cancel, s.Wg
	cancel()
	wg.Wait()
}

// LogError writes keyvals to the server's error-level logger and returns
// whatever error the logger reports.
func (s *Server) LogError(keyvals ...any) error {
	logger := s.ErrorLog
	return logger.Log(keyvals...)
}

// Closed reports, without blocking, whether the server's context has been
// canceled. That happens through Close, Cancel, or cancellation of the parent
// context passed to NewServer.
func (s *Server) Closed() bool {
	// Per the context contract, Err is non-nil exactly when Done is closed.
	return s.Ctx.Err() != nil
}

// propagateClose starts a goroutine, tracked by Wg, that waits for Ctx to be
// canceled and then closes Listener. Closing the listener unblocks a pending
// Accept in Serve, which then observes Closed() == true and returns nil.
func (s *Server) propagateClose() {
	closeOnDone := func() {
		defer s.Wg.Done()
		<-s.Ctx.Done()
		// The close error is irrelevant: shutdown is already under way.
		_ = s.Listener.Close()
	}

	s.Wg.Add(1)
	go closeOnDone()
}

// JoinDefaultPort appends ":<port>" if and only if address does not already
// specify a port, producing a "host:port" string.
//
// A bracketed IPv6 literal such as "[::1]" gets the port added
// ("[::1]:port"), while "[::1]:80" is returned unchanged. Any other address
// containing a colon is assumed to already carry a port and is returned
// unchanged; this includes an unbracketed IPv6 literal such as "::1". An
// empty address yields ":port".
//
// An address that opens a bracket but never closes it is malformed. It is
// returned unchanged rather than guessed at, so the caller's subsequent dial
// or listen reports the error.
func JoinDefaultPort(address string, port string) string {
	if strings.HasPrefix(address, "[") {
		end := strings.LastIndexByte(address, ']')
		if end < 0 {
			// No closing bracket: there is no host to extract, and slicing
			// address[1:end] would panic.
			return address
		}
		if end+1 < len(address) && address[end+1] == ':' {
			return address
		}
		return net.JoinHostPort(address[1:end], port)
	}

	if strings.Contains(address, ":") {
		return address
	}
	return net.JoinHostPort(address, port)
}