diff options
-rw-r--r-- | internal/pathtree.go | 46 | ||||
-rw-r--r-- | internal/pathtree_test.go | 8 | ||||
-rw-r--r-- | router.go | 17 |
3 files changed, 45 insertions, 26 deletions
diff --git a/internal/pathtree.go b/internal/pathtree.go index 0ab6b81..bbb1f2c 100644 --- a/internal/pathtree.go +++ b/internal/pathtree.go @@ -80,9 +80,6 @@ func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] { func (sn *segmentNode[V]) Add(pattern []string, value V) { if len(pattern) == 0 { - if sn.value != nil { - panic("pattern already exists") - } sn.value = &value return } @@ -91,13 +88,15 @@ func (sn *segmentNode[V]) Add(pattern []string, value V) { // single-segment param-capturing wildcard ("/:username/") type wildcardNode[V any] struct { + prefix string param string value *V subtree subtree[V] } func makeWildcard[V any](pattern []string, value V) wildcardNode[V] { - node := wildcardNode[V]{param: pattern[0][1:]} + prefix, param, _ := strings.Cut(pattern[0], ":") + node := wildcardNode[V]{prefix: prefix, param: param} if len(pattern) == 1 { node.value = &value } else { @@ -119,7 +118,7 @@ func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] { } if wn.param != "" { m.paramkey = wn.param - m.paramval = segments[0] + m.paramval = strings.TrimPrefix(segments[0], wn.prefix) } if len(segments) == 1 { @@ -131,9 +130,6 @@ func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] { func (wn *wildcardNode[V]) Add(pattern []string, value V) { if len(pattern) == 0 { - if wn.value != nil { - panic("pattern already exists") - } wn.value = &value return } @@ -178,6 +174,9 @@ func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] { } for _, wc := range st.wildcards { + if wc.prefix != "" && !strings.HasPrefix(segments[0], wc.prefix) { + continue + } candidate := wc.Match(segments, m) if best == nil || candidate.length > best.length { best = candidate @@ -198,22 +197,20 @@ func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] { func (st *subtree[V]) Add(pattern []string, value V) { if len(pattern[0]) > 0 { - switch pattern[0][0] { - case '*': + switch { + case pattern[0][0] == '*': if len(pattern) > 1 { panic("invalid pattern: segments after *remainder") } - if st.remainder != nil { - panic("pattern already exists") - } st.remainder = &remainderNode[V]{ param: pattern[0][1:], value: value, } return - case ':': - child := st.wildcards.Find(pattern[0][1:]) + case strings.Contains(pattern[0], ":"): + prefix, param, _ := strings.Cut(pattern[0], ":") + child := st.wildcards.Find(prefix, param) if child != nil { child.Add(pattern[1:], value) } else { @@ -291,19 +288,30 @@ func (cs childSegments[V]) Find(label string) *segmentNode[V] { type childWildcards[V any] []wildcardNode[V] func (cw childWildcards[V]) Len() int { return len(cw) } -func (cw childWildcards[V]) Less(i, j int) bool { return cw[i].param < cw[j].param } +func (cw childWildcards[V]) Less(i, j int) bool { return wildcardLess(cw[i], cw[j]) } func (cw childWildcards[V]) Swap(i, j int) { cw[i], cw[j] = cw[j], cw[i] } -func (cw childWildcards[V]) Find(param string) *wildcardNode[V] { +func (cw childWildcards[V]) Find(prefix, param string) *wildcardNode[V] { i := sort.Search(len(cw), func(i int) bool { - return param <= cw[i].param + return wildcardLess(wildcardNode[V]{prefix: prefix, param: param}, cw[i]) }) - if i < len(cw) && cw[i].param == param { + if i < len(cw) && cw[i].prefix == prefix && cw[i].param == param { return &cw[i] } return nil } +func wildcardLess[V any](a, b wildcardNode[V]) bool { + switch { + case a.prefix < b.prefix: + return true + case a.prefix == b.prefix: + return a.param <= b.param + default: + return false + } +} + // linked list we build up as we match our way through the path segments // - if there is a parameter captured it goes in paramkey/paramval // - if there was a value it goes in value diff --git a/internal/pathtree_test.go b/internal/pathtree_test.go index e152e85..d9df757 100644 --- a/internal/pathtree_test.go +++ b/internal/pathtree_test.go @@ -34,6 +34,7 @@ func TestPathTree(t *testing.T) { {"/a/b", 3}, {"/c", 4}, {"/x/:y/z/*rest", 5}, + {"/f/prefix:y/z/*rest", 6}, }, paths: map[string]matchresult{ "/a": { @@ -66,6 +67,13 @@ func TestPathTree(t *testing.T) { "/": { failed: true, }, + "/f/mismatch/z/more": { + failed: true, + }, + "/f/prefixblargh/z/more": { + value: 6, + params: map[string]string{"y": "blargh", "rest": "more"}, + }, }, }, { @@ -2,6 +2,7 @@ package sliderule import ( "context" + "path" "strings" "tildegit.org/tjp/sliderule/internal" @@ -10,9 +11,11 @@ import ( // Router stores a mapping of request path patterns to handlers. // // Pattern may begin with "/" and then contain slash-delimited segments. -// - Segments beginning with colon (:) are wildcards and will match any path +// - Segments containing a colon (:) are wildcards and will match any path // segment at that location. It may optionally have a word after the colon, -// which will be the parameter name the path segment is captured into. +// which will be the parameter name the path segment is captured into. It +// may also optionally have text before the colon, in which case the pattern +// will not match unless the request path segment contains that prefix. // - Segments beginning with asterisk (*) are remainder wildcards. This must // come last and will capture any remainder of the path. It may have a name // after the asterisk which will be the parameter name. @@ -50,7 +53,7 @@ func (r Router) Handle(ctx context.Context, request *Request) *Response { return nil } - return handler.Handle(context.WithValue(ctx, routeParamsKey, params), request) + return handler.Handle(context.WithValue(ctx, RouteParamsKey, params), request) } // Handler builds a Handler @@ -83,8 +86,8 @@ func (r *Router) Mount(prefix string, subrouter *Router) { prefix = strings.TrimSuffix(prefix, "/") for _, subroute := range subrouter.tree.Routes() { - r.Route(prefix+"/"+subroute.Pattern, subroute.Value) - if subroute.Pattern == "" { + r.Route(path.Join(prefix, subroute.Pattern), subroute.Value) + if subroute.Pattern == "/" { r.Route(prefix, subroute.Value) } } @@ -111,7 +114,7 @@ func (r *Router) Use(mw Middleware) { // If Router was used but no parameters were captured in the pattern, it // returns a non-nil empty map. func RouteParams(ctx context.Context) map[string]string { - if m, ok := ctx.Value(routeParamsKey).(map[string]string); ok { + if m, ok := ctx.Value(RouteParamsKey).(map[string]string); ok { return m } return nil @@ -119,4 +122,4 @@ func RouteParams(ctx context.Context) map[string]string { type routeParamsKeyType struct{} -var routeParamsKey = routeParamsKeyType{} +var RouteParamsKey = routeParamsKeyType{} |