summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortjpcc <tjp@ctrl-c.club>2023-02-14 20:10:57 -0700
committertjpcc <tjp@ctrl-c.club>2023-02-14 20:13:25 -0700
commitfcf545c27c70fb795a631a42738c486c07092e83 (patch)
treef533c819c8f4a382f796235b739f8fece0313588
parent18d69173b4e23a2edd8c07c35f7a5b927587e6d7 (diff)
Router improvements.
- test coverage for Router, not just PathTree - Router.Mount() now flattens routes into the parent router - Router.Use() implemented to set middleware on a router itself
-rw-r--r--internal/pathtree.go48
-rw-r--r--router.go60
-rw-r--r--router_test.go80
3 files changed, 158 insertions, 30 deletions
diff --git a/internal/pathtree.go b/internal/pathtree.go
index 563e85c..7da4c2b 100644
--- a/internal/pathtree.go
+++ b/internal/pathtree.go
@@ -32,6 +32,15 @@ func (pt *PathTree[V]) Add(pattern string, value V) {
}
}
+type Route[V any] struct {
+ Pattern string
+ Value V
+}
+
+func (pt PathTree[V]) Routes() []Route[V] {
+ return pt.tree.routes()
+}
+
// pattern segment which must be a specific string ("/users/").
type segmentNode[V any] struct {
label string
@@ -216,6 +225,45 @@ func (st *subtree[V]) Add(pattern []string, value V) {
}
}
+func (st subtree[V]) routes() []Route[V] {
+ routes := []Route[V]{}
+ for _, seg := range st.segments {
+ if seg.value != nil {
+ routes = append(routes, Route[V]{
+ Pattern: seg.label,
+ Value: *seg.value,
+ })
+ }
+ for _, r := range seg.subtree.routes() {
+ r.Pattern = seg.label + "/" + r.Pattern
+ routes = append(routes, r)
+ }
+ }
+
+ for _, wc := range st.wildcards {
+ if wc.value != nil {
+ routes = append(routes, Route[V]{
+ Pattern: ":" + wc.param,
+ Value: *wc.value,
+ })
+ }
+ for _, r := range wc.subtree.routes() {
+ r.Pattern = ":" + wc.param + "/" + r.Pattern
+ routes = append(routes, r)
+ }
+ }
+
+ if st.remainder != nil {
+ rn := *st.remainder
+ routes = append(routes, Route[V]{
+ Pattern: "*" + rn.param,
+ Value: rn.value,
+ })
+ }
+
+ return routes
+}
+
type childSegments[V any] []segmentNode[V]
func (cs childSegments[V]) Len() int { return len(cs) }
diff --git a/router.go b/router.go
index 50cc41f..1d8e93d 100644
--- a/router.go
+++ b/router.go
@@ -2,8 +2,6 @@ package gus
import (
"context"
- "crypto/tls"
- "net/url"
"strings"
"tildegit.org/tjp/gus/internal"
@@ -27,11 +25,18 @@ import (
// The zero value is a usable Router which will fail to match any requst path.
type Router struct {
tree internal.PathTree[Handler]
+
+ middleware []Middleware
+ routeAdded bool
}
// Route adds a handler to the router under a path pattern.
-func (r Router) Route(pattern string, handler Handler) {
+func (r *Router) Route(pattern string, handler Handler) {
+ for i := len(r.middleware) - 1; i >= 0; i-- {
+ handler = r.middleware[i](handler)
+ }
r.tree.Add(pattern, handler)
+ r.routeAdded = true
}
// Handler matches against the request path and dipatches to a route handler.
@@ -59,6 +64,8 @@ func (r Router) Handler(ctx context.Context, request *Request) *Response {
}
// Match returns the matched handler and captured path parameters, or nils.
+//
+// The returned handlers will be wrapped with any middleware attached to the router.
func (r Router) Match(request *Request) (Handler, map[string]string) {
handler, params := r.tree.Match(request.Path)
if handler == nil {
@@ -72,19 +79,27 @@ func (r Router) Match(request *Request) (Handler, map[string]string) {
// The prefix pattern may include segment :wildcards, but no *remainder segment. The
// mounted sub-router should have patterns which only include the portion of the path
// after whatever was matched by the prefix pattern.
-func (r Router) Mount(prefix string, subrouter *Router) {
+func (r *Router) Mount(prefix string, subrouter *Router) {
prefix = strings.TrimSuffix(prefix, "/")
- r.Route(prefix+"/*"+subrouterPathKey, func(ctx context.Context, request *Request) *Response {
- r := cloneRequest(request)
- r.Path = "/" + RouteParams(ctx)[subrouterPathKey]
- return subrouter.Handler(ctx, r)
- })
-
- // TODO: better approach. the above works but it's a little hacky
- // - add a method to PathTree that returns all the registered patterns
- // and their associated handlers
- // - have Mount pull those out of the subrouter, prepend the prefix to
- // all its patterns, and re-add them to the parent router.
+
+ for _, subroute := range subrouter.tree.Routes() {
+ r.Route(prefix+"/"+subroute.Pattern, subroute.Value)
+ }
+}
+
+// Use attaches a middleware to the router.
+//
+// Any routes set on the router will have their handlers decorated by the attached
+// middlewares in reverse order (the first middleware attached will be the outer-most:
+// first to see requests and the last to see responses).
+//
+// Use will panic if Route or Mount have already been called on the router -
+// middlewares must be set before any routes.
+func (r *Router) Use(mw Middleware) {
+ if r.routeAdded {
+ panic("all middlewares must be added prior to adding routes")
+ }
+ r.middleware = append(r.middleware, mw)
}
// RouteParams gathers captured path parameters from the request context.
@@ -104,18 +119,3 @@ const subrouterPathKey = "subrouter_path"
type routeParamsKeyType struct{}
var routeParamsKey = routeParamsKeyType{}
-
-func cloneRequest(start *Request) *Request {
- end := &Request{}
- *end = *start
-
- end.URL = &url.URL{}
- *end.URL = *start.URL
-
- if start.TLSState != nil {
- end.TLSState = &tls.ConnectionState{}
- *end.TLSState = *start.TLSState
- }
-
- return end
-}
diff --git a/router_test.go b/router_test.go
new file mode 100644
index 0000000..6f9c915
--- /dev/null
+++ b/router_test.go
@@ -0,0 +1,80 @@
+package gus_test
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "tildegit.org/tjp/gus"
+ "tildegit.org/tjp/gus/gemini"
+)
+
+func h1(_ context.Context, _ *gus.Request) *gus.Response {
+ return gemini.Success("", &bytes.Buffer{})
+}
+
+func mw1(h gus.Handler) gus.Handler {
+ return func(ctx context.Context, req *gus.Request) *gus.Response {
+ resp := h(ctx, req)
+ resp.Body = io.MultiReader(resp.Body, bytes.NewBufferString("\nmiddleware 1"))
+ return resp
+ }
+}
+
+func mw2(h gus.Handler) gus.Handler {
+ return func(ctx context.Context, req *gus.Request) *gus.Response {
+ resp := h(ctx, req)
+ resp.Body = io.MultiReader(resp.Body, bytes.NewBufferString("\nmiddleware 2"))
+ return resp
+ }
+}
+
+func TestRouterUse(t *testing.T) {
+ r := &gus.Router{}
+ r.Use(mw1)
+ r.Use(mw2)
+ r.Route("/", h1)
+
+ request, err := gemini.ParseRequest(bytes.NewBufferString("/\r\n"))
+ require.Nil(t, err)
+
+ response := r.Handler(context.Background(), request)
+ require.NotNil(t, response)
+
+ body, err := io.ReadAll(response.Body)
+ require.Nil(t, err)
+
+ assert.Equal(t, "\nmiddleware 2\nmiddleware 1", string(body))
+}
+
+func TestRouterMount(t *testing.T) {
+ outer := &gus.Router{}
+ outer.Use(mw2)
+
+ inner := &gus.Router{}
+ inner.Use(mw1)
+ inner.Route("/bar", h1)
+
+ outer.Mount("/foo", inner)
+
+ request, err := gemini.ParseRequest(bytes.NewBufferString("/foo/bar\r\n"))
+ require.Nil(t, err)
+
+ response := outer.Handler(context.Background(), request)
+ require.NotNil(t, response)
+
+ body, err := io.ReadAll(response.Body)
+ require.Nil(t, err)
+
+ assert.Equal(t, "\nmiddleware 1\nmiddleware 2", string(body))
+
+ request, err = gemini.ParseRequest(bytes.NewBufferString("/foo\r\n"))
+ require.Nil(t, err)
+
+ response = outer.Handler(context.Background(), request)
+ require.Nil(t, response)
+}