diff options
Diffstat (limited to 'gemini')
-rw-r--r-- | gemini/response.go | 33 | ||||
-rw-r--r-- | gemini/roundtrip_test.go | 4 | ||||
-rw-r--r-- | gemini/serve.go | 158 |
3 files changed, 53 insertions, 142 deletions
diff --git a/gemini/response.go b/gemini/response.go index 0452462..b8797da 100644 --- a/gemini/response.go +++ b/gemini/response.go @@ -6,6 +6,7 @@ import ( "errors" "io" "strconv" + "sync" "tildegit.org/tjp/gus" ) @@ -284,19 +285,17 @@ func ParseResponse(rdr io.Reader) (*gus.Response, error) { }, nil } -type ResponseReader interface { - io.Reader - io.WriterTo - io.Closer -} - -func NewResponseReader(response *gus.Response) ResponseReader { - return &responseReader{ Response: response } +func NewResponseReader(response *gus.Response) gus.ResponseReader { + return &responseReader{ + Response: response, + once: &sync.Once{}, + } } type responseReader struct { *gus.Response reader io.Reader + once *sync.Once } func (rdr *responseReader) Read(b []byte) (int, error) { @@ -310,16 +309,14 @@ func (rdr *responseReader) WriteTo(dst io.Writer) (int64, error) { } func (rdr *responseReader) ensureReader() { - if rdr.reader != nil { - return - } - - hdr := bytes.NewBuffer(rdr.headerLine()) - if rdr.Body != nil { - rdr.reader = io.MultiReader(hdr, rdr.Body) - } else { - rdr.reader = hdr - } + rdr.once.Do(func() { + hdr := bytes.NewBuffer(rdr.headerLine()) + if rdr.Body != nil { + rdr.reader = io.MultiReader(hdr, rdr.Body) + } else { + rdr.reader = hdr + } + }) } func (rdr responseReader) headerLine() []byte { diff --git a/gemini/roundtrip_test.go b/gemini/roundtrip_test.go index ab7baa4..a9d9b59 100644 --- a/gemini/roundtrip_test.go +++ b/gemini/roundtrip_test.go @@ -24,7 +24,7 @@ func TestRoundTrip(t *testing.T) { return gemini.Success("text/gemini", bytes.NewBufferString("you've found my page")) } - server, err := gemini.NewServer(context.Background(), nil, tlsConf, "tcp", "127.0.0.1:0", handler) + server, err := gemini.NewServer(context.Background(), "localhost", "tcp", "127.0.0.1:0", handler, nil, tlsConf) require.Nil(t, err) go func() { @@ -69,7 +69,7 @@ func TestTitanRequest(t *testing.T) { return gemini.Success("", nil) } - server, err := gemini.NewServer(context.Background(), nil, tlsConf, "tcp", "127.0.0.1:0", handler) + server, err := gemini.NewServer(context.Background(), "localhost", "tcp", "127.0.0.1:0", handler, nil, tlsConf) require.Nil(t, err) go func() { diff --git a/gemini/serve.go b/gemini/serve.go index abed257..55998d6 100644 --- a/gemini/serve.go +++ b/gemini/serve.go @@ -5,14 +5,13 @@ import ( "context" "crypto/tls" "errors" - "fmt" "io" "net" "strconv" "strings" - "sync" "tildegit.org/tjp/gus" + "tildegit.org/tjp/gus/internal" "tildegit.org/tjp/gus/logging" ) @@ -25,127 +24,59 @@ type titanRequestBodyKey struct{} var TitanRequestBody = titanRequestBodyKey{} type server struct { - ctx context.Context - errorLog logging.Logger - network string - address string - cancel context.CancelFunc - wg *sync.WaitGroup - listener net.Listener - handler gus.Handler + internal.Server + + handler gus.Handler } +func (s server) Protocol() string { return "GEMINI" } + // NewServer builds a gemini server. func NewServer( ctx context.Context, - errorLog logging.Logger, - tlsConfig *tls.Config, + hostname string, network string, address string, handler gus.Handler, + errorLog logging.Logger, + tlsConfig *tls.Config, ) (gus.Server, error) { - listener, err := net.Listen(network, address) - if err != nil { - return nil, err - } + s := &server{handler: handler} - addr := listener.Addr() - - s := &server{ - ctx: ctx, - errorLog: errorLog, - network: addr.Network(), - address: addr.String(), - wg: &sync.WaitGroup{}, - listener: tls.NewListener(listener, tlsConfig), - handler: handler, + if strings.IndexByte(hostname, ':') < 0 { + hostname = net.JoinHostPort(hostname, "1965") } - return s, nil -} - -// Serve starts the server and blocks until it is closed. -// -// This function will allocate resources which are not cleaned up until -// Close() is called. -// -// It will respect cancellation of the context the server was created with, -// but be aware that Close() must still be called in that case to avoid -// dangling goroutines. -// -// On titan protocol requests it sets a key/value pair in the context. The -// key is TitanRequestBody, and the value is a *bufio.Reader from which the -// request body can be read. -func (s *server) Serve() error { - s.wg.Add(1) - defer s.wg.Done() - - s.ctx, s.cancel = context.WithCancel(s.ctx) - - s.wg.Add(1) - s.propagateCancel() - - for { - conn, err := s.listener.Accept() - if err != nil { - if s.Closed() { - err = nil - } else { - _ = s.errorLog.Log("msg", "accept error", "error", err) - } - - return err - } - - s.wg.Add(1) - go s.handleConn(conn) + internalServer, err := internal.NewServer(ctx, hostname, network, address, errorLog, s.handleConn) + if err != nil { + return nil, err } -} - -func (s *server) Close() { - s.cancel() - s.wg.Wait() -} - -func (s *server) Network() string { - return s.network -} + s.Server = internalServer -func (s *server) Address() string { - return s.address -} + s.Listener = tls.NewListener(s.Listener, tlsConfig) -func (s *server) Hostname() string { - host, _, _ := net.SplitHostPort(s.address) - return host -} - -func (s *server) Port() string { - _, portStr, _ := net.SplitHostPort(s.address) - return portStr + return s, nil } func (s *server) handleConn(conn net.Conn) { - defer s.wg.Done() - defer conn.Close() - buf := bufio.NewReader(conn) var response *gus.Response - req, err := ParseRequest(buf) + request, err := ParseRequest(buf) if err != nil { response = BadRequest(err.Error()) } else { - req.Server = s - req.RemoteAddr = conn.RemoteAddr() + request.Server = s + request.RemoteAddr = conn.RemoteAddr() + if tlsconn, ok := conn.(*tls.Conn); ok { state := tlsconn.ConnectionState() - req.TLSState = &state + request.TLSState = &state } - ctx := s.ctx - if req.Scheme == "titan" { - len, err := sizeParam(req.Path) + ctx := s.Ctx + if request.Scheme == "titan" { + len, err := sizeParam(request.Path) if err == nil { ctx = context.WithValue( ctx, @@ -155,15 +86,16 @@ func (s *server) handleConn(conn net.Conn) { } } - defer func() { - if r := recover(); r != nil { - err := fmt.Errorf("%s", r) - _ = s.errorLog.Log("msg", "panic in handler", "err", err) - _, _ = io.Copy(conn, NewResponseReader(Failure(err))) - } - }() - - response = s.handler(ctx, req) + /* + defer func() { + if r := recover(); r != nil { + err := fmt.Errorf("%s", r) + _ = s.LogError("msg", "panic in handler", "err", err) + _, _ = io.Copy(conn, NewResponseReader(Failure(err))) + } + }() + */ + response = s.handler(ctx, request) if response == nil { response = NotFound("Resource does not exist.") } @@ -173,24 +105,6 @@ func (s *server) handleConn(conn net.Conn) { _, _ = io.Copy(conn, NewResponseReader(response)) } -func (s *server) propagateCancel() { - go func() { - defer s.wg.Done() - - <-s.ctx.Done() - _ = s.listener.Close() - }() -} - -func (s *server) Closed() bool { - select { - case <-s.ctx.Done(): - return true - default: - return false - } -} - func sizeParam(path string) (int, error) { _, rest, found := strings.Cut(path, ";") if !found { |