package internal

import (
	"sort"
	"strings"
)

// PathTree maps "/"-separated path patterns to values of type V. The zero
// value is an empty tree; patterns are registered with Add.
type PathTree[V any] struct {
	tree subtree[V] // routes for non-empty patterns
	root *V         // value for the empty pattern ("" or "/"); nil if unset
}

// Match resolves path against the registered patterns. One leading "/" is
// ignored. The empty path resolves to the root route when one has been
// added; otherwise the path is split on "/" and matched against the tree.
// It returns the matched value and the parameters captured along the way;
// if the tree yields no match at all, both results are nil.
func (pt PathTree[V]) Match(path string) (*V, map[string]string) {
	rel := strings.TrimPrefix(path, "/")
	if pt.root != nil && rel == "" {
		return pt.root, map[string]string{}
	}

	hit := pt.tree.Match(strings.Split(rel, "/"), nil)
	if hit != nil {
		return hit.value, hit.params()
	}
	return nil, nil
}

// Add registers value under pattern. A single leading "/" is ignored and the
// rest is split on "/". Each segment is one of three kinds:
// a literal; a "prefix:name" wildcard (prefix optional) that captures one
// segment; or a final "*name" remainder that captures everything after it.
// An empty pattern sets the root route. Adding a pattern again replaces its
// value. Add panics if a "*remainder" segment is followed by more segments.
func (pt *PathTree[V]) Add(pattern string, value V) {
	rel := strings.TrimPrefix(pattern, "/")
	if rel != "" {
		pt.tree.Add(strings.Split(rel, "/"), value)
		return
	}
	pt.root = &value
}

// Route is one registered pattern together with its value, as reported by
// PathTree.Routes.
type Route[V any] struct {
	Pattern string // "/"-joined pattern segments, without a leading "/"
	Value   V      // the value registered for Pattern
}

// Routes lists every registered pattern with its value. The root route, if
// set, comes first as the empty pattern "". The tree's routes follow in its
// traversal order. Patterns carry no leading "/".
func (pt PathTree[V]) Routes() []Route[V] {
	if pt.root == nil {
		return pt.tree.routes()
	}
	rootRoute := Route[V]{Pattern: "", Value: *pt.root}
	return append([]Route[V]{rootRoute}, pt.tree.routes()...)
}

// segmentNode is a literal pattern segment: the path segment must equal
// label exactly (e.g. "users" in "/users/").
type segmentNode[V any] struct {
	label   string     // the literal segment text
	value   *V         // value for a pattern ending here; nil for pass-through nodes
	subtree subtree[V] // children for the following segments
}

// makeSegment builds a fresh literal node for pattern[0] and attaches value
// at the end of pattern. The value goes on the node itself when pattern has
// one segment; otherwise it goes in a newly built subtree. pattern must be
// non-empty.
func makeSegment[V any](pattern []string, value V) segmentNode[V] {
	node := segmentNode[V]{label: pattern[0]}
	node.Add(pattern[1:], value)
	return node
}

// Match consumes segments[0], which the caller has already matched against
// sn.label, and records that step on top of m. If more segments remain, the
// search continues in sn.subtree. Otherwise the new step is returned as is;
// its value is nil when no pattern ends at this node.
func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] {
	total := len(sn.label)
	if m != nil {
		total += m.length
	}
	step := &match[V]{value: sn.value, length: total, prev: m}

	if rest := segments[1:]; len(rest) > 0 {
		return sn.subtree.Match(rest, step)
	}
	return step
}

// Add attaches value at pattern, relative to this node. An empty pattern
// sets (or replaces) this node's own value.
func (sn *segmentNode[V]) Add(pattern []string, value V) {
	if len(pattern) > 0 {
		sn.subtree.Add(pattern, value)
		return
	}
	sn.value = &value
}

// wildcardNode matches any single path segment that starts with prefix. It
// captures the rest of that segment as parameter param, e.g. "/:username/"
// or "/v:version/". An empty param matches without capturing.
type wildcardNode[V any] struct {
	prefix  string     // literal text the segment must start with; may be empty
	param   string     // name for the captured text; may be empty
	value   *V         // value for a pattern ending here; nil for pass-through nodes
	subtree subtree[V] // children for the following segments
}

// makeWildcard builds a fresh wildcard node from pattern[0]. That segment is
// split at its first ":" into the literal prefix and the parameter name.
// value is attached at the end of pattern. pattern must be non-empty.
func makeWildcard[V any](pattern []string, value V) wildcardNode[V] {
	var node wildcardNode[V]
	node.prefix, node.param, _ = strings.Cut(pattern[0], ":")
	node.Add(pattern[1:], value)
	return node
}

// Match consumes segments[0], which the caller has already checked against
// wn.prefix, and records that step on top of m. The whole segment counts
// toward the match length. When wn.param is set, the segment text after the
// prefix is captured under that name. The search continues in wn.subtree
// while segments remain.
func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] {
	seg := segments[0]
	step := &match[V]{value: wn.value, length: len(seg), prev: m}
	if m != nil {
		step.length += m.length
	}

	if wn.param != "" {
		step.paramkey, step.paramval = wn.param, strings.TrimPrefix(seg, wn.prefix)
	}

	if len(segments) > 1 {
		return wn.subtree.Match(segments[1:], step)
	}
	return step
}

// Add attaches value at pattern, relative to this node. An empty pattern
// sets (or replaces) this node's own value.
func (wn *wildcardNode[V]) Add(pattern []string, value V) {
	switch len(pattern) {
	case 0:
		wn.value = &value
	default:
		wn.subtree.Add(pattern, value)
	}
}

// "rest of the path" capturing node ("/*path_info"). It matches all remaining
// path segments and captures them, re-joined with "/", as parameter param.
// It is always the final node of a pattern: subtree.Add panics if anything
// follows it.
type remainderNode[V any] struct {
	param string // parameter name; empty means match without capturing
	value V      // the registered value (stored inline, always present)
}

// Match always succeeds, consuming every remaining segment. It adds nothing
// to the match length, so any literal or wildcard alternative of equal or
// greater depth outranks it. When rn.param is set, the remaining segments
// are captured, joined with "/".
func (rn remainderNode[V]) Match(segments []string, m *match[V]) *match[V] {
	step := &match[V]{value: &rn.value, prev: m}
	if m != nil {
		step.length = m.length
	}

	if rn.param == "" {
		return step
	}
	step.paramkey = rn.param
	step.paramval = strings.Join(segments, "/")
	return step
}

// all the children under a tree node: literal segments and wildcards, each
// kept sorted so lookups can binary search, plus at most one remainder.
type subtree[V any] struct {
	segments  childSegments[V]  // literal children, sorted by label
	wildcards childWildcards[V] // wildcard children, sorted by prefix then param
	remainder *remainderNode[V] // optional "*name" catch-all; nil if none
}

// Match finds the best route for segments among this node's children.
//
// segments is the not-yet-consumed rest of the request path and must be
// non-empty. m is the match chain built so far (nil at the top level).
// Candidates are ranked by the cumulative length of matched segment text.
// A remainder contributes no length, so it only wins when nothing else
// matches. A wildcard replaces the current best only when strictly longer,
// so on a tie the earliest wildcard in sort order wins. A literal segment
// also wins ties, so exact text beats a wildcard of the same length.
//
// A candidate counts only if it ends on a node with a registered value.
// Otherwise a pattern that merely passes through a node (e.g. the "css" in
// "/static/css/main.css" when the request is "/static/css") would shadow a
// valid remainder or wildcard route. A nil return means no route matches.
func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] {
	var best *match[V]

	if st.remainder != nil {
		best = st.remainder.Match(segments, m)
	}

	for _, wc := range st.wildcards {
		if wc.prefix != "" && !strings.HasPrefix(segments[0], wc.prefix) {
			continue
		}
		candidate := wc.Match(segments, m)
		// A wildcard whose subtree matched nothing yields nil; it must be
		// skipped rather than dereferenced when best is already set.
		if candidate == nil || candidate.value == nil {
			continue
		}
		if best == nil || candidate.length > best.length {
			best = candidate
		}
	}

	if seg := st.segments.Find(segments[0]); seg != nil {
		candidate := seg.Match(segments, m)
		if candidate != nil && candidate.value != nil &&
			(best == nil || candidate.length >= best.length) {
			best = candidate
		}
	}

	return best
}

// Add inserts value at pattern (which must be non-empty) beneath this
// subtree. The first segment decides where it goes:
// a segment beginning with '*' becomes the remainder and must be the last
// segment; one containing ':' becomes a wildcard keyed by its prefix and
// parameter name; anything else is a literal segment. Existing children are
// reused. New ones are appended and the child slice is re-sorted so Find
// can binary search.
func (st *subtree[V]) Add(pattern []string, value V) {
	head, rest := pattern[0], pattern[1:]

	if strings.HasPrefix(head, "*") {
		if len(rest) > 0 {
			panic("invalid pattern: segments after *remainder")
		}
		st.remainder = &remainderNode[V]{param: head[1:], value: value}
		return
	}

	if prefix, param, isWildcard := strings.Cut(head, ":"); isWildcard {
		if existing := st.wildcards.Find(prefix, param); existing != nil {
			existing.Add(rest, value)
			return
		}
		st.wildcards = append(st.wildcards, makeWildcard(pattern, value))
		sort.Sort(st.wildcards)
		return
	}

	if existing := st.segments.Find(head); existing != nil {
		existing.Add(rest, value)
		return
	}
	st.segments = append(st.segments, makeSegment(pattern, value))
	sort.Sort(st.segments)
}

// routes flattens the subtree into (pattern, value) pairs. Patterns are
// relative to this subtree and use the same segment syntax Add accepts.
// Order: literal segments (by label), then wildcards (by prefix/param), then
// the remainder. Each node is listed before its own descendants.
func (st subtree[V]) routes() []Route[V] {
	routes := []Route[V]{}

	for _, seg := range st.segments {
		if seg.value != nil {
			routes = append(routes, Route[V]{Pattern: seg.label, Value: *seg.value})
		}
		for _, r := range seg.subtree.routes() {
			r.Pattern = seg.label + "/" + r.Pattern
			routes = append(routes, r)
		}
	}

	for _, wc := range st.wildcards {
		// Re-attach the literal prefix so a wildcard registered as
		// "v:version" is reported as such rather than as ":version".
		label := wc.prefix + ":" + wc.param
		if wc.value != nil {
			routes = append(routes, Route[V]{Pattern: label, Value: *wc.value})
		}
		for _, r := range wc.subtree.routes() {
			r.Pattern = label + "/" + r.Pattern
			routes = append(routes, r)
		}
	}

	if rn := st.remainder; rn != nil {
		routes = append(routes, Route[V]{Pattern: "*" + rn.param, Value: rn.value})
	}

	return routes
}

// childSegments holds literal children ordered by label; it implements
// sort.Interface so it can be re-sorted after insertion.
type childSegments[V any] []segmentNode[V]

func (segs childSegments[V]) Len() int { return len(segs) }

func (segs childSegments[V]) Less(a, b int) bool {
	return segs[a].label < segs[b].label
}

func (segs childSegments[V]) Swap(a, b int) {
	segs[a], segs[b] = segs[b], segs[a]
}

// Find binary-searches the (sorted) children for an exact label. It returns
// a pointer to the slice element, so callers can modify the node in place,
// or nil when no child has that label.
func (segs childSegments[V]) Find(label string) *segmentNode[V] {
	pos := sort.Search(len(segs), func(k int) bool { return segs[k].label >= label })
	if pos == len(segs) || segs[pos].label != label {
		return nil
	}
	return &segs[pos]
}

// childWildcards holds wildcard children ordered by (prefix, param). It
// implements sort.Interface so it can be re-sorted after insertion.
type childWildcards[V any] []wildcardNode[V]

func (cw childWildcards[V]) Len() int           { return len(cw) }
func (cw childWildcards[V]) Less(i, j int) bool { return wildcardLess(cw[i], cw[j]) }
func (cw childWildcards[V]) Swap(i, j int)      { cw[i], cw[j] = cw[j], cw[i] }

// Find binary-searches the (sorted) children for an exact prefix/param pair.
// It returns a pointer to the slice element, so callers can modify the node
// in place, or nil when there is no such child.
func (cw childWildcards[V]) Find(prefix, param string) *wildcardNode[V] {
	target := wildcardNode[V]{prefix: prefix, param: param}
	// First index whose element is not less than target, i.e. >= target.
	i := sort.Search(len(cw), func(i int) bool {
		return !wildcardLess(cw[i], target)
	})
	if i < len(cw) && cw[i].prefix == prefix && cw[i].param == param {
		return &cw[i]
	}
	return nil
}

// wildcardLess orders wildcards by prefix, then by parameter name. It is a
// strict ordering (wildcardLess(a, a) is false), as sort.Interface.Less
// requires; the previous "<=" on param made equal elements compare less.
func wildcardLess[V any](a, b wildcardNode[V]) bool {
	if a.prefix != b.prefix {
		return a.prefix < b.prefix
	}
	return a.param < b.param
}

// match is one link in the chain built while walking the path segments. The
// head of the chain is the deepest step reached, and prev points back toward
// the first segment.
type match[V any] struct {
	paramkey string // name of the parameter captured at this step, if any
	paramval string // captured text for paramkey

	value  *V        // value of the node this step ended on (may be nil)
	length int       // cumulative matched segment length; longest match wins
	prev   *match[V] // previous step; nil at the first segment
}

// params collects every captured parameter along the chain into a new map.
// The chain is walked from this step back to the first, so if the same key
// was captured more than once, the capture closest to the first segment wins.
func (m match[_]) params() map[string]string {
	out := make(map[string]string)
	for cur := &m; cur != nil; cur = cur.prev {
		if cur.paramkey == "" {
			continue
		}
		out[cur.paramkey] = cur.paramval
	}
	return out
}