package internal import ( "sort" "strings" ) type PathTree[V any] struct { tree subtree[V] root *V } func (pt PathTree[V]) Match(path string) (*V, map[string]string) { path = strings.TrimPrefix(path, "/") if path == "" { return pt.root, map[string]string{} } m := pt.tree.Match(strings.Split(path, "/"), nil) if m == nil { return nil, nil } return m.value, m.params() } func (pt *PathTree[V]) Add(pattern string, value V) { pattern = strings.TrimPrefix(pattern, "/") if pattern == "" { pt.root = &value } else { pt.tree.Add(strings.Split(pattern, "/"), value) } } type Route[V any] struct { Pattern string Value V } func (pt PathTree[V]) Routes() []Route[V] { return pt.tree.routes() } // pattern segment which must be a specific string ("/users/"). type segmentNode[V any] struct { label string value *V subtree subtree[V] } func makeSegment[V any](pattern []string, value V) segmentNode[V] { node := segmentNode[V]{label: pattern[0]} if len(pattern) == 1 { node.value = &value } else { node.subtree.Add(pattern[1:], value) } return node } func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] { var l int if m != nil { l = m.length } m = &match[V]{ value: sn.value, length: l + len(sn.label), prev: m, } if len(segments) == 1 { return m } return sn.subtree.Match(segments[1:], m) } func (sn *segmentNode[V]) Add(pattern []string, value V) { if len(pattern) == 0 { if sn.value != nil { panic("pattern already exists") } sn.value = &value return } sn.subtree.Add(pattern, value) } // single-segment param-capturing wildcard ("/:username/") type wildcardNode[V any] struct { param string value *V subtree subtree[V] } func makeWildcard[V any](pattern []string, value V) wildcardNode[V] { node := wildcardNode[V]{param: pattern[0][1:]} if len(pattern) == 1 { node.value = &value } else { node.subtree.Add(pattern[1:], value) } return node } func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] { var l int if m != nil { l = m.length } m = &match[V]{ value: wn.value, length: l + len(segments[0]), prev: m, } if wn.param != "" { m.paramkey = wn.param m.paramval = segments[0] } if len(segments) == 1 { return m } return wn.subtree.Match(segments[1:], m) } func (wn *wildcardNode[V]) Add(pattern []string, value V) { if len(pattern) == 0 { if wn.value != nil { panic("pattern already exists") } wn.value = &value return } wn.subtree.Add(pattern, value) } // "rest of the path" capturing node ("/*path_info") - always a final node type remainderNode[V any] struct { param string value V } func (rn remainderNode[V]) Match(segments []string, m *match[V]) *match[V] { m = &match[V]{ value: &rn.value, length: m.length, prev: m, } if rn.param != "" { m.paramkey = rn.param m.paramval = strings.Join(segments, "/") } return m } // all the children under a tree node type subtree[V any] struct { segments childSegments[V] wildcards childWildcards[V] remainder *remainderNode[V] } func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] { var best *match[V] if st.remainder != nil { best = st.remainder.Match(segments, m) } for _, wc := range st.wildcards { candidate := wc.Match(segments, m) if best == nil || candidate.length > best.length { best = candidate } } childSeg := st.segments.Find(segments[0]) if childSeg == nil { return best } candidate := childSeg.Match(segments, m) if best == nil || (candidate != nil && candidate.length >= best.length) { best = candidate } return best } func (st *subtree[V]) Add(pattern []string, value V) { if len(pattern[0]) == 0 { panic("invalid pattern") } switch pattern[0][0] { case '*': if len(pattern) > 1 { panic("invalid pattern: segments after *remainder") } if st.remainder != nil { panic("pattern already exists") } st.remainder = &remainderNode[V]{ param: pattern[0][1:], value: value, } case ':': child := st.wildcards.Find(pattern[0][1:]) if child != nil { child.Add(pattern[1:], value) } else { st.wildcards = append(st.wildcards, makeWildcard(pattern, value)) sort.Sort(st.wildcards) } default: child := st.segments.Find(pattern[0]) if child != nil { child.Add(pattern[1:], value) } else { st.segments = append(st.segments, makeSegment(pattern, value)) sort.Sort(st.segments) } } } func (st subtree[V]) routes() []Route[V] { routes := []Route[V]{} for _, seg := range st.segments { if seg.value != nil { routes = append(routes, Route[V]{ Pattern: seg.label, Value: *seg.value, }) } for _, r := range seg.subtree.routes() { r.Pattern = seg.label + "/" + r.Pattern routes = append(routes, r) } } for _, wc := range st.wildcards { if wc.value != nil { routes = append(routes, Route[V]{ Pattern: ":" + wc.param, Value: *wc.value, }) } for _, r := range wc.subtree.routes() { r.Pattern = ":" + wc.param + "/" + r.Pattern routes = append(routes, r) } } if st.remainder != nil { rn := *st.remainder routes = append(routes, Route[V]{ Pattern: "*" + rn.param, Value: rn.value, }) } return routes } type childSegments[V any] []segmentNode[V] func (cs childSegments[V]) Len() int { return len(cs) } func (cs childSegments[V]) Less(i, j int) bool { return cs[i].label < cs[j].label } func (cs childSegments[V]) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } func (cs childSegments[V]) Find(label string) *segmentNode[V] { idx := sort.Search(len(cs), func(i int) bool { return label <= cs[i].label }) if idx < len(cs) && cs[idx].label == label { return &cs[idx] } return nil } type childWildcards[V any] []wildcardNode[V] func (cw childWildcards[V]) Len() int { return len(cw) } func (cw childWildcards[V]) Less(i, j int) bool { return cw[i].param < cw[j].param } func (cw childWildcards[V]) Swap(i, j int) { cw[i], cw[j] = cw[j], cw[i] } func (cw childWildcards[V]) Find(param string) *wildcardNode[V] { i := sort.Search(len(cw), func(i int) bool { return param <= cw[i].param }) if i < len(cw) && cw[i].param == param { return &cw[i] } return nil } // linked list we build up as we match our way through the path segments // - if there is a parameter captured it goes in paramkey/paramval // - if there was a value it goes in value // - also store cumulative length to choose the longest match type match[V any] struct { paramkey string paramval string value *V length int prev *match[V] } func (m match[_]) params() map[string]string { mch := &m out := make(map[string]string) for mch != nil { if mch.paramkey != "" { out[mch.paramkey] = mch.paramval } mch = mch.prev } return out }