summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/pathtree.go275
-rw-r--r--internal/pathtree_test.go103
-rw-r--r--internal/server.go10
-rw-r--r--router.go71
4 files changed, 455 insertions, 4 deletions
diff --git a/internal/pathtree.go b/internal/pathtree.go
new file mode 100644
index 0000000..563e85c
--- /dev/null
+++ b/internal/pathtree.go
@@ -0,0 +1,275 @@
+package internal
+
+import (
+ "sort"
+ "strings"
+)
+
+type PathTree[V any] struct {
+ tree subtree[V]
+ root *V
+}
+
+func (pt PathTree[V]) Match(path string) (*V, map[string]string) {
+ path = strings.TrimPrefix(path, "/")
+ if path == "" {
+ return pt.root, map[string]string{}
+ }
+ m := pt.tree.Match(strings.Split(path, "/"), nil)
+ if m == nil {
+ return nil, nil
+ }
+
+ return m.value, m.params()
+}
+
+func (pt *PathTree[V]) Add(pattern string, value V) {
+ pattern = strings.TrimPrefix(pattern, "/")
+ if pattern == "" {
+ pt.root = &value
+ } else {
+ pt.tree.Add(strings.Split(pattern, "/"), value)
+ }
+}
+
+// pattern segment which must be a specific string ("/users/").
+type segmentNode[V any] struct {
+ label string
+ value *V
+ subtree subtree[V]
+}
+
+func makeSegment[V any](pattern []string, value V) segmentNode[V] {
+ node := segmentNode[V]{label: pattern[0]}
+ if len(pattern) == 1 {
+ node.value = &value
+ } else {
+ node.subtree.Add(pattern[1:], value)
+ }
+ return node
+}
+
+func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] {
+ var l int
+ if m != nil {
+ l = m.length
+ }
+ m = &match[V]{
+ value: sn.value,
+ length: l + len(sn.label),
+ prev: m,
+ }
+ if len(segments) == 1 {
+ return m
+ }
+ return sn.subtree.Match(segments[1:], m)
+}
+
+func (sn *segmentNode[V]) Add(pattern []string, value V) {
+ if len(pattern) == 0 {
+ if sn.value != nil {
+ panic("pattern already exists")
+ }
+ sn.value = &value
+ return
+ }
+ sn.subtree.Add(pattern, value)
+}
+
+// single-segment param-capturing wildcard ("/:username/")
+type wildcardNode[V any] struct {
+ param string
+ value *V
+ subtree subtree[V]
+}
+
+func makeWildcard[V any](pattern []string, value V) wildcardNode[V] {
+ node := wildcardNode[V]{param: pattern[0][1:]}
+ if len(pattern) == 1 {
+ node.value = &value
+ } else {
+ node.subtree.Add(pattern[1:], value)
+ }
+ return node
+}
+
+func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] {
+ var l int
+ if m != nil {
+ l = m.length
+ }
+
+ m = &match[V]{
+ value: wn.value,
+ length: l + len(segments[0]),
+ prev: m,
+ }
+ if wn.param != "" {
+ m.paramkey = wn.param
+ m.paramval = segments[0]
+ }
+
+ if len(segments) == 1 {
+ return m
+ }
+
+ return wn.subtree.Match(segments[1:], m)
+}
+
+func (wn *wildcardNode[V]) Add(pattern []string, value V) {
+ if len(pattern) == 0 {
+ if wn.value != nil {
+ panic("pattern already exists")
+ }
+ wn.value = &value
+ return
+ }
+ wn.subtree.Add(pattern, value)
+}
+
+// "rest of the path" capturing node ("/*path_info") - always a final node
+type remainderNode[V any] struct {
+ param string
+ value V
+}
+
+func (rn remainderNode[V]) Match(segments []string, m *match[V]) *match[V] {
+ m = &match[V]{
+ value: &rn.value,
+ length: m.length,
+ prev: m,
+ }
+ if rn.param != "" {
+ m.paramkey = rn.param
+ m.paramval = strings.Join(segments, "/")
+ }
+ return m
+}
+
+// all the children under a tree node
+type subtree[V any] struct {
+ segments childSegments[V]
+ wildcards childWildcards[V]
+ remainder *remainderNode[V]
+}
+
+func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] {
+ var best *match[V]
+
+ if st.remainder != nil {
+ best = st.remainder.Match(segments, m)
+ }
+
+ for _, wc := range st.wildcards {
+ candidate := wc.Match(segments, m)
+ if best == nil || candidate.length > best.length {
+ best = candidate
+ }
+ }
+
+ childSeg := st.segments.Find(segments[0])
+ if childSeg == nil {
+ return best
+ }
+ candidate := childSeg.Match(segments, m)
+ if best == nil || (candidate != nil && candidate.length >= best.length) {
+ best = candidate
+ }
+
+ return best
+}
+
+func (st *subtree[V]) Add(pattern []string, value V) {
+ if len(pattern[0]) == 0 {
+ panic("invalid pattern")
+ }
+
+ switch pattern[0][0] {
+ case '*':
+ if len(pattern) > 1 {
+ panic("invalid pattern: segments after *remainder")
+ }
+ if st.remainder != nil {
+ panic("pattern already exists")
+ }
+
+ st.remainder = &remainderNode[V]{
+ param: pattern[0][1:],
+ value: value,
+ }
+ case ':':
+ child := st.wildcards.Find(pattern[0][1:])
+ if child != nil {
+ child.Add(pattern[1:], value)
+ } else {
+ st.wildcards = append(st.wildcards, makeWildcard(pattern, value))
+ sort.Sort(st.wildcards)
+ }
+ default:
+ child := st.segments.Find(pattern[0])
+ if child != nil {
+ child.Add(pattern[1:], value)
+ } else {
+ st.segments = append(st.segments, makeSegment(pattern, value))
+ sort.Sort(st.segments)
+ }
+ }
+}
+
+type childSegments[V any] []segmentNode[V]
+
+func (cs childSegments[V]) Len() int { return len(cs) }
+func (cs childSegments[V]) Less(i, j int) bool { return cs[i].label < cs[j].label }
+func (cs childSegments[V]) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
+
+func (cs childSegments[V]) Find(label string) *segmentNode[V] {
+ idx := sort.Search(len(cs), func(i int) bool {
+ return label <= cs[i].label
+ })
+ if idx < len(cs) && cs[idx].label == label {
+ return &cs[idx]
+ }
+ return nil
+}
+
+type childWildcards[V any] []wildcardNode[V]
+
+func (cw childWildcards[V]) Len() int { return len(cw) }
+func (cw childWildcards[V]) Less(i, j int) bool { return cw[i].param < cw[j].param }
+func (cw childWildcards[V]) Swap(i, j int) { cw[i], cw[j] = cw[j], cw[i] }
+
+func (cw childWildcards[V]) Find(param string) *wildcardNode[V] {
+ i := sort.Search(len(cw), func(i int) bool {
+ return param <= cw[i].param
+ })
+ if i < len(cw) && cw[i].param == param {
+ return &cw[i]
+ }
+ return nil
+}
+
+// linked list we build up as we match our way through the path segments
+// - if there is a parameter captured it goes in paramkey/paramval
+// - if there was a value it goes in value
+// - also store cumulative length to choose the longest match
+type match[V any] struct {
+ paramkey string
+ paramval string
+
+ value *V
+ length int
+ prev *match[V]
+}
+
+func (m match[_]) params() map[string]string {
+ mch := &m
+ out := make(map[string]string)
+ for mch != nil {
+ if mch.paramkey != "" {
+ out[mch.paramkey] = mch.paramval
+ }
+ mch = mch.prev
+ }
+
+ return out
+}
diff --git a/internal/pathtree_test.go b/internal/pathtree_test.go
new file mode 100644
index 0000000..11f6848
--- /dev/null
+++ b/internal/pathtree_test.go
@@ -0,0 +1,103 @@
+package internal_test
+
+import (
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "tildegit.org/tjp/gus/internal"
+)
+
+func TestPathTree(t *testing.T) {
+ type pattern struct {
+ string
+ int
+ }
+ type matchresult struct {
+ value int
+ params map[string]string
+ failed bool
+ }
+ tests := []struct {
+ // path-matching pattern and the integers the tree stores for each
+ patterns []pattern
+
+ // paths to match against, and the int we get and captured params
+ paths map[string]matchresult
+ }{
+ {
+ patterns: []pattern{
+ {"/a", 1},
+ {"/a/*rest", 2},
+ {"/a/b", 3},
+ {"/c", 4},
+ {"/x/:y/z/*rest", 5},
+ },
+ paths: map[string]matchresult{
+ "/a": {
+ value: 1,
+ params: map[string]string{},
+ },
+ "/a/other": {
+ value: 2,
+ params: map[string]string{"rest": "other"},
+ },
+ "/a/b": {
+ value: 3,
+ params: map[string]string{},
+ },
+ "/a/b/c": {
+ value: 2,
+ params: map[string]string{"rest": "b/c"},
+ },
+ "/c": {
+ value: 4,
+ params: map[string]string{},
+ },
+ "/c/d": {
+ failed: true,
+ },
+ "/x/foo/z/bar/baz": {
+ value: 5,
+ params: map[string]string{"y": "foo", "rest": "bar/baz"},
+ },
+ "/": {
+ failed: true,
+ },
+ },
+ },
+ {
+ patterns: []pattern{
+ {"/", 10},
+ },
+ paths: map[string]matchresult{
+ "/": {value: 10, params: map[string]string{}},
+ "/foo": {failed: true},
+ },
+ },
+ }
+
+ for i, test := range tests {
+ t.Run(strconv.Itoa(i+1), func(t *testing.T) {
+ tree := &internal.PathTree[int]{}
+ for _, pattern := range test.patterns {
+ tree.Add(pattern.string, pattern.int)
+ }
+
+ for path, result := range test.paths {
+ t.Run(path, func(t *testing.T) {
+ n, params := tree.Match(path)
+ if result.failed {
+ require.Nil(t, n)
+ } else {
+ require.NotNil(t, n)
+ assert.Equal(t, result.value, *n)
+ assert.Equal(t, result.params, params)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/internal/server.go b/internal/server.go
index 38e478c..3efdf6e 100644
--- a/internal/server.go
+++ b/internal/server.go
@@ -4,17 +4,19 @@ import (
"context"
"net"
"sync"
-
- "tildegit.org/tjp/gus/logging"
)
+type logger interface {
+ Log(keyvals ...any) error
+}
+
type Server struct {
Ctx context.Context
Cancel context.CancelFunc
Wg *sync.WaitGroup
Listener net.Listener
HandleConn connHandler
- ErrorLog logging.Logger
+ ErrorLog logger
Host string
NetworkAddr net.Addr
}
@@ -26,7 +28,7 @@ func NewServer(
hostname string,
network string,
address string,
- errorLog logging.Logger,
+ errorLog logger,
handleConn connHandler,
) (Server, error) {
listener, err := net.Listen(network, address)
diff --git a/router.go b/router.go
new file mode 100644
index 0000000..8408246
--- /dev/null
+++ b/router.go
@@ -0,0 +1,71 @@
+package gus
+
+import (
+ "context"
+
+ "tildegit.org/tjp/gus/internal"
+)
+
+// Router stores a mapping of request path patterns to handlers.
+//
+// Pattern may begin with "/" and then contain slash-delimited segments.
+// - Segments beginning with colon (:) are wildcards and will match any path
+// segment at that location. It may optionally have a word after the colon,
+// which will be the parameter name the path segment is captured into.
+// - Segments beginning with asterisk (*) are remainder wildcards. This must
+// come last and will capture any remainder of the path. It may have a name
+// after the asterisk which will be the parameter name.
+// - Any other segment in the pattern must match a path segment exactly.
+//
+// These patterns do not match any path which shares a prefix, rather then
+// full path must match a pattern. If you want to only match a prefix of the
+// path you can end the pattern with a *remainder segment.
+//
+// The zero value is a usable Router which will fail to match any requst path.
+type Router struct {
+ tree internal.PathTree[Handler]
+}
+
+// Route adds a handler to the router under a path pattern.
+func (r Router) Route(pattern string, handler Handler) {
+ r.tree.Add(pattern, handler)
+}
+
+// Handler matches against the request path and dipatches to a route handler.
+//
+// If no route matches, it returns a nil response.
+// Captured path parameters will be stored in the context passed into the handler
+// and can be retrieved with RouteParams().
+func (r Router) Handler(ctx context.Context, request *Request) *Response {
+ handler, params := r.Match(request)
+ if handler == nil {
+ return nil
+ }
+
+ return handler(context.WithValue(ctx, routeParamsKey, params), request)
+}
+
+// Match returns the matched handler and captured path parameters, or nils.
+func (r Router) Match(request *Request) (Handler, map[string]string) {
+ handler, params := r.tree.Match(request.Path)
+ if handler == nil {
+ return nil, nil
+ }
+ return *handler, params
+}
+
+// RouteParams gathers captured path parameters from the request context.
+//
+// If the context doesn't contain a parameter map, it returns nil.
+// If Router was used but no parameters were captured in the pattern, it
+// returns a non-nil empty map.
+func RouteParams(ctx context.Context) map[string]string {
+ if m, ok := ctx.Value(routeParamsKey).(map[string]string); ok {
+ return m
+ }
+ return nil
+}
+
+type routeParamsKeyType struct{}
+
+var routeParamsKey = routeParamsKeyType{}