package internal import ( "context" "net" "strings" "sync" ) type logger interface { Log(keyvals ...any) error } type Server struct { Ctx context.Context Cancel context.CancelFunc Wg *sync.WaitGroup Listener net.Listener HandleConn connHandler ErrorLog logger Host string NetworkAddr net.Addr } type connHandler func(net.Conn) func NewServer( ctx context.Context, hostname string, network string, address string, errorLog logger, handleConn connHandler, ) (Server, error) { listener, err := net.Listen(network, address) if err != nil { return Server{}, err } networkAddr := listener.Addr() ctx, cancel := context.WithCancel(ctx) return Server{ Ctx: ctx, Cancel: cancel, Wg: &sync.WaitGroup{}, Listener: listener, HandleConn: handleConn, ErrorLog: errorLog, Host: hostname, NetworkAddr: networkAddr, }, nil } func (s *Server) Serve() error { s.Wg.Add(1) defer s.Wg.Done() s.propagateClose() for { conn, err := s.Listener.Accept() if err != nil { if s.Closed() { err = nil } else { _ = s.ErrorLog.Log("msg", "accept error", "error", err) } return err } s.Wg.Add(1) go func() { defer s.Wg.Done() defer func() { _ = conn.Close() }() s.HandleConn(conn) }() } } func (s *Server) Hostname() string { host, _, _ := net.SplitHostPort(s.Host) return host } func (s *Server) Port() string { _, port, _ := net.SplitHostPort(s.Host) return port } func (s *Server) Network() string { return s.NetworkAddr.Network() } func (s *Server) Address() string { return s.NetworkAddr.String() } func (s *Server) Close() { s.Cancel() s.Wg.Wait() } func (s *Server) LogError(keyvals ...any) error { return s.ErrorLog.Log(keyvals...) } func (s *Server) Closed() bool { select { case <-s.Ctx.Done(): return true default: return false } } func (s *Server) propagateClose() { s.Wg.Add(1) go func() { defer s.Wg.Done() <-s.Ctx.Done() _ = s.Listener.Close() }() } // JoinDefaultPort appends ":" iff the address does not already contain a port. 
//
// An empty address yields ":" followed by port. A bracketed IPv6 literal such
// as "[::1]" is unwrapped and re-bracketed by net.JoinHostPort; one that is
// missing its closing bracket is returned unchanged so that whatever parses
// the result can reject it. An unbracketed address with more than one colon
// (for example "::1") is treated as an IPv6 literal that has no port yet.
func JoinDefaultPort(address string, port string) string {
	if address == "" {
		// Indexing address[0] below would panic on an empty string.
		return net.JoinHostPort(address, port)
	}

	if address[0] == '[' {
		hend := strings.LastIndexByte(address, ']')
		if hend < 0 {
			// Malformed literal: slicing address[1:hend] would panic.
			return address
		}
		if strings.HasPrefix(address[hend+1:], ":") {
			// "[host]:port" already carries a port.
			return address
		}
		return net.JoinHostPort(address[1:hend], port)
	}

	// Exactly one colon means host:port. No colon is a bare host, and several
	// colons are a bare IPv6 literal; both still need the port.
	if strings.Count(address, ":") == 1 {
		return address
	}
	return net.JoinHostPort(address, port)
}