summaryrefslogtreecommitdiff
path: root/gemini
diff options
context:
space:
mode:
Diffstat (limited to 'gemini')
-rw-r--r--gemini/response.go33
-rw-r--r--gemini/roundtrip_test.go4
-rw-r--r--gemini/serve.go158
3 files changed, 53 insertions, 142 deletions
diff --git a/gemini/response.go b/gemini/response.go
index 0452462..b8797da 100644
--- a/gemini/response.go
+++ b/gemini/response.go
@@ -6,6 +6,7 @@ import (
"errors"
"io"
"strconv"
+ "sync"
"tildegit.org/tjp/gus"
)
@@ -284,19 +285,17 @@ func ParseResponse(rdr io.Reader) (*gus.Response, error) {
}, nil
}
-type ResponseReader interface {
- io.Reader
- io.WriterTo
- io.Closer
-}
-
-func NewResponseReader(response *gus.Response) ResponseReader {
- return &responseReader{ Response: response }
+func NewResponseReader(response *gus.Response) gus.ResponseReader {
+ return &responseReader{
+ Response: response,
+ once: &sync.Once{},
+ }
}
type responseReader struct {
*gus.Response
reader io.Reader
+ once *sync.Once
}
func (rdr *responseReader) Read(b []byte) (int, error) {
@@ -310,16 +309,14 @@ func (rdr *responseReader) WriteTo(dst io.Writer) (int64, error) {
}
func (rdr *responseReader) ensureReader() {
- if rdr.reader != nil {
- return
- }
-
- hdr := bytes.NewBuffer(rdr.headerLine())
- if rdr.Body != nil {
- rdr.reader = io.MultiReader(hdr, rdr.Body)
- } else {
- rdr.reader = hdr
- }
+ rdr.once.Do(func() {
+ hdr := bytes.NewBuffer(rdr.headerLine())
+ if rdr.Body != nil {
+ rdr.reader = io.MultiReader(hdr, rdr.Body)
+ } else {
+ rdr.reader = hdr
+ }
+ })
}
func (rdr responseReader) headerLine() []byte {
diff --git a/gemini/roundtrip_test.go b/gemini/roundtrip_test.go
index ab7baa4..a9d9b59 100644
--- a/gemini/roundtrip_test.go
+++ b/gemini/roundtrip_test.go
@@ -24,7 +24,7 @@ func TestRoundTrip(t *testing.T) {
return gemini.Success("text/gemini", bytes.NewBufferString("you've found my page"))
}
- server, err := gemini.NewServer(context.Background(), nil, tlsConf, "tcp", "127.0.0.1:0", handler)
+ server, err := gemini.NewServer(context.Background(), "localhost", "tcp", "127.0.0.1:0", handler, nil, tlsConf)
require.Nil(t, err)
go func() {
@@ -69,7 +69,7 @@ func TestTitanRequest(t *testing.T) {
return gemini.Success("", nil)
}
- server, err := gemini.NewServer(context.Background(), nil, tlsConf, "tcp", "127.0.0.1:0", handler)
+ server, err := gemini.NewServer(context.Background(), "localhost", "tcp", "127.0.0.1:0", handler, nil, tlsConf)
require.Nil(t, err)
go func() {
diff --git a/gemini/serve.go b/gemini/serve.go
index abed257..55998d6 100644
--- a/gemini/serve.go
+++ b/gemini/serve.go
@@ -5,14 +5,13 @@ import (
"context"
"crypto/tls"
"errors"
- "fmt"
"io"
"net"
"strconv"
"strings"
- "sync"
"tildegit.org/tjp/gus"
+ "tildegit.org/tjp/gus/internal"
"tildegit.org/tjp/gus/logging"
)
@@ -25,127 +24,59 @@ type titanRequestBodyKey struct{}
var TitanRequestBody = titanRequestBodyKey{}
type server struct {
- ctx context.Context
- errorLog logging.Logger
- network string
- address string
- cancel context.CancelFunc
- wg *sync.WaitGroup
- listener net.Listener
- handler gus.Handler
+ internal.Server
+
+ handler gus.Handler
}
+func (s server) Protocol() string { return "GEMINI" }
+
// NewServer builds a gemini server.
func NewServer(
ctx context.Context,
- errorLog logging.Logger,
- tlsConfig *tls.Config,
+ hostname string,
network string,
address string,
handler gus.Handler,
+ errorLog logging.Logger,
+ tlsConfig *tls.Config,
) (gus.Server, error) {
- listener, err := net.Listen(network, address)
- if err != nil {
- return nil, err
- }
+ s := &server{handler: handler}
- addr := listener.Addr()
-
- s := &server{
- ctx: ctx,
- errorLog: errorLog,
- network: addr.Network(),
- address: addr.String(),
- wg: &sync.WaitGroup{},
- listener: tls.NewListener(listener, tlsConfig),
- handler: handler,
+ if strings.IndexByte(hostname, ':') < 0 {
+ hostname = net.JoinHostPort(hostname, "1965")
}
- return s, nil
-}
-
-// Serve starts the server and blocks until it is closed.
-//
-// This function will allocate resources which are not cleaned up until
-// Close() is called.
-//
-// It will respect cancellation of the context the server was created with,
-// but be aware that Close() must still be called in that case to avoid
-// dangling goroutines.
-//
-// On titan protocol requests it sets a key/value pair in the context. The
-// key is TitanRequestBody, and the value is a *bufio.Reader from which the
-// request body can be read.
-func (s *server) Serve() error {
- s.wg.Add(1)
- defer s.wg.Done()
-
- s.ctx, s.cancel = context.WithCancel(s.ctx)
-
- s.wg.Add(1)
- s.propagateCancel()
-
- for {
- conn, err := s.listener.Accept()
- if err != nil {
- if s.Closed() {
- err = nil
- } else {
- _ = s.errorLog.Log("msg", "accept error", "error", err)
- }
-
- return err
- }
-
- s.wg.Add(1)
- go s.handleConn(conn)
+ internalServer, err := internal.NewServer(ctx, hostname, network, address, errorLog, s.handleConn)
+ if err != nil {
+ return nil, err
}
-}
-
-func (s *server) Close() {
- s.cancel()
- s.wg.Wait()
-}
-
-func (s *server) Network() string {
- return s.network
-}
+ s.Server = internalServer
-func (s *server) Address() string {
- return s.address
-}
+ s.Listener = tls.NewListener(s.Listener, tlsConfig)
-func (s *server) Hostname() string {
- host, _, _ := net.SplitHostPort(s.address)
- return host
-}
-
-func (s *server) Port() string {
- _, portStr, _ := net.SplitHostPort(s.address)
- return portStr
+ return s, nil
}
func (s *server) handleConn(conn net.Conn) {
- defer s.wg.Done()
- defer conn.Close()
-
buf := bufio.NewReader(conn)
var response *gus.Response
- req, err := ParseRequest(buf)
+ request, err := ParseRequest(buf)
if err != nil {
response = BadRequest(err.Error())
} else {
- req.Server = s
- req.RemoteAddr = conn.RemoteAddr()
+ request.Server = s
+ request.RemoteAddr = conn.RemoteAddr()
+
if tlsconn, ok := conn.(*tls.Conn); ok {
state := tlsconn.ConnectionState()
- req.TLSState = &state
+ request.TLSState = &state
}
- ctx := s.ctx
- if req.Scheme == "titan" {
- len, err := sizeParam(req.Path)
+ ctx := s.Ctx
+ if request.Scheme == "titan" {
+ len, err := sizeParam(request.Path)
if err == nil {
ctx = context.WithValue(
ctx,
@@ -155,15 +86,16 @@ func (s *server) handleConn(conn net.Conn) {
}
}
- defer func() {
- if r := recover(); r != nil {
- err := fmt.Errorf("%s", r)
- _ = s.errorLog.Log("msg", "panic in handler", "err", err)
- _, _ = io.Copy(conn, NewResponseReader(Failure(err)))
- }
- }()
-
- response = s.handler(ctx, req)
+ /*
+ defer func() {
+ if r := recover(); r != nil {
+ err := fmt.Errorf("%s", r)
+ _ = s.LogError("msg", "panic in handler", "err", err)
+ _, _ = io.Copy(conn, NewResponseReader(Failure(err)))
+ }
+ }()
+ */
+ response = s.handler(ctx, request)
if response == nil {
response = NotFound("Resource does not exist.")
}
@@ -173,24 +105,6 @@ func (s *server) handleConn(conn net.Conn) {
_, _ = io.Copy(conn, NewResponseReader(response))
}
-func (s *server) propagateCancel() {
- go func() {
- defer s.wg.Done()
-
- <-s.ctx.Done()
- _ = s.listener.Close()
- }()
-}
-
-func (s *server) Closed() bool {
- select {
- case <-s.ctx.Done():
- return true
- default:
- return false
- }
-}
-
func sizeParam(path string) (int, error) {
_, rest, found := strings.Cut(path, ";")
if !found {