diff options
author | tjpcc <tjp@ctrl-c.club> | 2023-02-02 16:15:53 -0700 |
---|---|---|
committer | tjpcc <tjp@ctrl-c.club> | 2023-02-02 16:15:53 -0700 |
commit | ac024567e880f0da59557f0051018f4ac932c6ad (patch) | |
tree | a1d0d5661d5f980ed56716e430a4e4de4b5e7bd3 | |
parent | b7cb13b4e68568b014c868186f536439e92d662f (diff) |
Initial Router work.
- Router type, supports: adding handlers, serving, fetching the matching
handler for a route.
- Private PathTree type handles the modified radix trie.
-rw-r--r-- | internal/pathtree.go | 275 | ||||
-rw-r--r-- | internal/pathtree_test.go | 103 | ||||
-rw-r--r-- | internal/server.go | 10 | ||||
-rw-r--r-- | router.go | 71 |
4 files changed, 455 insertions, 4 deletions
diff --git a/internal/pathtree.go b/internal/pathtree.go new file mode 100644 index 0000000..563e85c --- /dev/null +++ b/internal/pathtree.go @@ -0,0 +1,275 @@ +package internal + +import ( + "sort" + "strings" +) + +type PathTree[V any] struct { + tree subtree[V] + root *V +} + +func (pt PathTree[V]) Match(path string) (*V, map[string]string) { + path = strings.TrimPrefix(path, "/") + if path == "" { + return pt.root, map[string]string{} + } + m := pt.tree.Match(strings.Split(path, "/"), nil) + if m == nil { + return nil, nil + } + + return m.value, m.params() +} + +func (pt *PathTree[V]) Add(pattern string, value V) { + pattern = strings.TrimPrefix(pattern, "/") + if pattern == "" { + pt.root = &value + } else { + pt.tree.Add(strings.Split(pattern, "/"), value) + } +} + +// pattern segment which must be a specific string ("/users/"). +type segmentNode[V any] struct { + label string + value *V + subtree subtree[V] +} + +func makeSegment[V any](pattern []string, value V) segmentNode[V] { + node := segmentNode[V]{label: pattern[0]} + if len(pattern) == 1 { + node.value = &value + } else { + node.subtree.Add(pattern[1:], value) + } + return node +} + +func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] { + var l int + if m != nil { + l = m.length + } + m = &match[V]{ + value: sn.value, + length: l + len(sn.label), + prev: m, + } + if len(segments) == 1 { + return m + } + return sn.subtree.Match(segments[1:], m) +} + +func (sn *segmentNode[V]) Add(pattern []string, value V) { + if len(pattern) == 0 { + if sn.value != nil { + panic("pattern already exists") + } + sn.value = &value + return + } + sn.subtree.Add(pattern, value) +} + +// single-segment param-capturing wildcard ("/:username/") +type wildcardNode[V any] struct { + param string + value *V + subtree subtree[V] +} + +func makeWildcard[V any](pattern []string, value V) wildcardNode[V] { + node := wildcardNode[V]{param: pattern[0][1:]} + if len(pattern) == 1 { + node.value = &value + } else { + node.subtree.Add(pattern[1:], value) + } + return node +} + +func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] { + var l int + if m != nil { + l = m.length + } + + m = &match[V]{ + value: wn.value, + length: l + len(segments[0]), + prev: m, + } + if wn.param != "" { + m.paramkey = wn.param + m.paramval = segments[0] + } + + if len(segments) == 1 { + return m + } + + return wn.subtree.Match(segments[1:], m) +} + +func (wn *wildcardNode[V]) Add(pattern []string, value V) { + if len(pattern) == 0 { + if wn.value != nil { + panic("pattern already exists") + } + wn.value = &value + return + } + wn.subtree.Add(pattern, value) +} + +// "rest of the path" capturing node ("/*path_info") - always a final node +type remainderNode[V any] struct { + param string + value V +} + +func (rn remainderNode[V]) Match(segments []string, m *match[V]) *match[V] { + m = &match[V]{ + value: &rn.value, + length: m.length, + prev: m, + } + if rn.param != "" { + m.paramkey = rn.param + m.paramval = strings.Join(segments, "/") + } + return m +} + +// all the children under a tree node +type subtree[V any] struct { + segments childSegments[V] + wildcards childWildcards[V] + remainder *remainderNode[V] +} + +func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] { + var best *match[V] + + if st.remainder != nil { + best = st.remainder.Match(segments, m) + } + + for _, wc := range st.wildcards { + candidate := wc.Match(segments, m) + if best == nil || candidate.length > best.length { + best = candidate + } + } + + childSeg := st.segments.Find(segments[0]) + if childSeg == nil { + return best + } + candidate := childSeg.Match(segments, m) + if best == nil || (candidate != nil && candidate.length >= best.length) { + best = candidate + } + + return best +} + +func (st *subtree[V]) Add(pattern []string, value V) { + if len(pattern[0]) == 0 { + panic("invalid pattern") + } + + switch pattern[0][0] { + case '*': + if len(pattern) > 1 { + panic("invalid pattern: segments after *remainder") + } + if st.remainder != nil { + panic("pattern already exists") + } + + st.remainder = &remainderNode[V]{ + param: pattern[0][1:], + value: value, + } + case ':': + child := st.wildcards.Find(pattern[0][1:]) + if child != nil { + child.Add(pattern[1:], value) + } else { + st.wildcards = append(st.wildcards, makeWildcard(pattern, value)) + sort.Sort(st.wildcards) + } + default: + child := st.segments.Find(pattern[0]) + if child != nil { + child.Add(pattern[1:], value) + } else { + st.segments = append(st.segments, makeSegment(pattern, value)) + sort.Sort(st.segments) + } + } +} + +type childSegments[V any] []segmentNode[V] + +func (cs childSegments[V]) Len() int { return len(cs) } +func (cs childSegments[V]) Less(i, j int) bool { return cs[i].label < cs[j].label } +func (cs childSegments[V]) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } + +func (cs childSegments[V]) Find(label string) *segmentNode[V] { + idx := sort.Search(len(cs), func(i int) bool { + return label <= cs[i].label + }) + if idx < len(cs) && cs[idx].label == label { + return &cs[idx] + } + return nil +} + +type childWildcards[V any] []wildcardNode[V] + +func (cw childWildcards[V]) Len() int { return len(cw) } +func (cw childWildcards[V]) Less(i, j int) bool { return cw[i].param < cw[j].param } +func (cw childWildcards[V]) Swap(i, j int) { cw[i], cw[j] = cw[j], cw[i] } + +func (cw childWildcards[V]) Find(param string) *wildcardNode[V] { + i := sort.Search(len(cw), func(i int) bool { + return param <= cw[i].param + }) + if i < len(cw) && cw[i].param == param { + return &cw[i] + } + return nil +} + +// linked list we build up as we match our way through the path segments +// - if there is a parameter captured it goes in paramkey/paramval +// - if there was a value it goes in value +// - also store cumulative length to choose the longest match +type match[V any] struct { + paramkey string + paramval string + + value *V + length int + prev *match[V] +} + +func (m match[_]) params() map[string]string { + mch := &m + out := make(map[string]string) + for mch != nil { + if mch.paramkey != "" { + out[mch.paramkey] = mch.paramval + } + mch = mch.prev + } + + return out +} diff --git a/internal/pathtree_test.go b/internal/pathtree_test.go new file mode 100644 index 0000000..11f6848 --- /dev/null +++ b/internal/pathtree_test.go @@ -0,0 +1,103 @@ +package internal_test + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "tildegit.org/tjp/gus/internal" +) + +func TestPathTree(t *testing.T) { + type pattern struct { + string + int + } + type matchresult struct { + value int + params map[string]string + failed bool + } + tests := []struct { + // path-matching pattern and the integers the tree stores for each + patterns []pattern + + // paths to match against, and the int we get and captured params + paths map[string]matchresult + }{ + { + patterns: []pattern{ + {"/a", 1}, + {"/a/*rest", 2}, + {"/a/b", 3}, + {"/c", 4}, + {"/x/:y/z/*rest", 5}, + }, + paths: map[string]matchresult{ + "/a": { + value: 1, + params: map[string]string{}, + }, + "/a/other": { + value: 2, + params: map[string]string{"rest": "other"}, + }, + "/a/b": { + value: 3, + params: map[string]string{}, + }, + "/a/b/c": { + value: 2, + params: map[string]string{"rest": "b/c"}, + }, + "/c": { + value: 4, + params: map[string]string{}, + }, + "/c/d": { + failed: true, + }, + "/x/foo/z/bar/baz": { + value: 5, + params: map[string]string{"y": "foo", "rest": "bar/baz"}, + }, + "/": { + failed: true, + }, + }, + }, + { + patterns: []pattern{ + {"/", 10}, + }, + paths: map[string]matchresult{ + "/": {value: 10, params: map[string]string{}}, + "/foo": {failed: true}, + }, + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i+1), func(t *testing.T) { + tree := &internal.PathTree[int]{} + for _, pattern := range test.patterns { + tree.Add(pattern.string, pattern.int) + } + + for path, result := range test.paths { + t.Run(path, func(t *testing.T) { + n, params := tree.Match(path) + if result.failed { + require.Nil(t, n) + } else { + require.NotNil(t, n) + assert.Equal(t, result.value, *n) + assert.Equal(t, result.params, params) + } + }) + } + }) + } +} diff --git a/internal/server.go b/internal/server.go index 38e478c..3efdf6e 100644 --- a/internal/server.go +++ b/internal/server.go @@ -4,17 +4,19 @@ import ( "context" "net" "sync" - - "tildegit.org/tjp/gus/logging" ) +type logger interface { + Log(keyvals ...any) error +} + type Server struct { Ctx context.Context Cancel context.CancelFunc Wg *sync.WaitGroup Listener net.Listener HandleConn connHandler - ErrorLog logging.Logger + ErrorLog logger Host string NetworkAddr net.Addr } @@ -26,7 +28,7 @@ func NewServer( hostname string, network string, address string, - errorLog logging.Logger, + errorLog logger, handleConn connHandler, ) (Server, error) { listener, err := net.Listen(network, address) diff --git a/router.go b/router.go new file mode 100644 index 0000000..8408246 --- /dev/null +++ b/router.go @@ -0,0 +1,71 @@ +package gus + +import ( + "context" + + "tildegit.org/tjp/gus/internal" +) + +// Router stores a mapping of request path patterns to handlers. +// +// Pattern may begin with "/" and then contain slash-delimited segments. +// - Segments beginning with colon (:) are wildcards and will match any path +// segment at that location. It may optionally have a word after the colon, +// which will be the parameter name the path segment is captured into. +// - Segments beginning with asterisk (*) are remainder wildcards. This must +// come last and will capture any remainder of the path. It may have a name +// after the asterisk which will be the parameter name. +// - Any other segment in the pattern must match a path segment exactly. +// +// These patterns do not match any path which shares a prefix, rather then +// full path must match a pattern. If you want to only match a prefix of the +// path you can end the pattern with a *remainder segment. +// +// The zero value is a usable Router which will fail to match any requst path. +type Router struct { + tree internal.PathTree[Handler] +} + +// Route adds a handler to the router under a path pattern. +func (r Router) Route(pattern string, handler Handler) { + r.tree.Add(pattern, handler) +} + +// Handler matches against the request path and dipatches to a route handler. +// +// If no route matches, it returns a nil response. +// Captured path parameters will be stored in the context passed into the handler +// and can be retrieved with RouteParams(). +func (r Router) Handler(ctx context.Context, request *Request) *Response { + handler, params := r.Match(request) + if handler == nil { + return nil + } + + return handler(context.WithValue(ctx, routeParamsKey, params), request) +} + +// Match returns the matched handler and captured path parameters, or nils. +func (r Router) Match(request *Request) (Handler, map[string]string) { + handler, params := r.tree.Match(request.Path) + if handler == nil { + return nil, nil + } + return *handler, params +} + +// RouteParams gathers captured path parameters from the request context. +// +// If the context doesn't contain a parameter map, it returns nil. +// If Router was used but no parameters were captured in the pattern, it +// returns a non-nil empty map. +func RouteParams(ctx context.Context) map[string]string { + if m, ok := ctx.Value(routeParamsKey).(map[string]string); ok { + return m + } + return nil +} + +type routeParamsKeyType struct{} + +var routeParamsKey = routeParamsKeyType{} |