package sliderule

import (
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net/http"
	neturl "net/url"

	"tildegit.org/tjp/sliderule/finger"
	"tildegit.org/tjp/sliderule/gemini"
	"tildegit.org/tjp/sliderule/gopher"
	"tildegit.org/tjp/sliderule/internal/types"
	"tildegit.org/tjp/sliderule/spartan"
)

// protocolClient is implemented by each per-protocol client that Client
// dispatches to, keyed by URL scheme.
type protocolClient interface {
	// RoundTrip sends a single request and returns its response as-is,
	// without following redirects.
	RoundTrip(request *Request) (*Response, error)

	// IsRedirect reports whether response asks the caller to go fetch a
	// different URL.
	IsRedirect(response *Response) bool
}

// Client is a multi-protocol client which handles all protocols known to sliderule.
//
// Build one with NewClient; the zero value has no protocols registered, so
// every request made through it fails as an unrecognized protocol.
type Client struct {
	// MaxRedirects bounds how many redirects Fetch will follow before it
	// gives up with ExceededMaxRedirects.
	MaxRedirects int

	// protos maps a URL scheme to the client implementing that protocol.
	protos map[string]protocolClient
}

// DefaultMaxRedirects is the MaxRedirects value assigned by NewClient.
const DefaultMaxRedirects int = 5

// ExceededMaxRedirects is returned by Client.Fetch when the resource is still
// a redirect after MaxRedirects redirects have been followed.
var ExceededMaxRedirects = errors.New("Client: exceeded MaxRedirects")

// NewClient builds a Client object.
//
// tlsConf may be nil, in which case gemini connections will be made without
// any client certificate. When non-nil, tlsConf is handed to the gemini
// client, and a private copy of it is used by the http/https transport.
func NewClient(tlsConf *tls.Config) Client {
	hc := httpClient{tp: http.DefaultTransport.(*http.Transport).Clone()}
	if tlsConf != nil {
		// Give the HTTP transport its own copy: net/http's HTTP/2 setup
		// appends to TLSClientConfig.NextProtos in place, which would
		// otherwise leak "h2"/"http/1.1" ALPN values into the config that
		// is shared with the gemini client (and race with its use).
		hc.tp.TLSClientConfig = tlsConf.Clone()
	}
	return Client{
		protos: map[string]protocolClient{
			"finger":  finger.Client{},
			"gopher":  gopher.Client{},
			"gemini":  gemini.NewClient(tlsConf),
			"spartan": spartan.NewClient(),
			"http":    hc,
			"https":   hc,
		},
		MaxRedirects: DefaultMaxRedirects,
	}
}

// RoundTrip sends a single request and returns the response.
//
// The request is handed to the protocol client registered for its URL
// scheme. A redirect response is returned as-is rather than followed; use
// Fetch to follow redirects.
func (c Client) RoundTrip(request *Request) (*Response, error) {
	if proto, found := c.protos[request.Scheme]; found {
		return proto.RoundTrip(request)
	}
	return nil, fmt.Errorf("unrecognized protocol: %s", request.Scheme)
}

// Fetch collects a resource from a URL, following any redirects.
//
// At most c.MaxRedirects redirects are followed. If the response is still a
// redirect after that, ExceededMaxRedirects is returned; a negative
// MaxRedirects therefore fails without making any request.
//
// Each redirect location is resolved against the URL that produced it, so
// absolute URLs, scheme-relative URLs and bare paths (as used by spartan)
// all lead to a complete URL.
func (c Client) Fetch(url string) (*Response, error) {
	u, err := neturl.Parse(url)
	if err != nil {
		return nil, err
	}

	for i := 0; i <= c.MaxRedirects; i++ {
		response, err := c.RoundTrip(&types.Request{URL: u})
		if err != nil {
			return nil, err
		}

		if !c.protos[u.Scheme].IsRedirect(response) {
			return response, nil
		}

		// This redirect response is discarded in favor of the next hop, so
		// release its body (for HTTP, its underlying connection). Closing is
		// best-effort; a failure here does not affect the next request.
		if closer, ok := response.Body.(io.Closer); ok {
			_ = closer.Close()
		}

		location := getRedirectLocation(u.Scheme, response.Meta)
		if location == "" {
			return nil, fmt.Errorf("redirect from %s has no location", u)
		}
		next, err := neturl.Parse(location)
		if err != nil {
			return nil, err
		}
		// ResolveReference returns next unchanged when it is absolute and
		// otherwise fills in scheme/host/path from the current URL.
		u = u.ResolveReference(next)
	}

	return nil, ExceededMaxRedirects
}

// getRedirectLocation extracts the redirect target from a response's Meta
// value for the given protocol.
//
// For gemini and spartan, Meta is expected to hold the target as a string.
// For http and https, Meta holds the *http.Response (as stored by
// httpClient.RoundTrip) and the target is its Location header.
//
// It returns "" for protocols not handled here, or when meta is not of the
// expected type, rather than panicking on a failed type assertion.
func getRedirectLocation(proto string, meta any) string {
	switch proto {
	case "gemini", "spartan":
		if location, ok := meta.(string); ok {
			return location
		}
	case "http", "https":
		if hresp, ok := meta.(*http.Response); ok && hresp != nil {
			return hresp.Header.Get("Location")
		}
	}
	return ""
}

// httpClient adapts an *http.Transport to the protocolClient interface; the
// same value serves both the http and https schemes.
type httpClient struct {
	tp *http.Transport
}

// RoundTrip performs a single GET request for request.URL using the
// underlying transport. As an http.RoundTripper, the transport does not
// follow redirects.
//
// The returned Response carries the *http.Response as its Meta and the HTTP
// response body as its Body; the caller is responsible for closing that body.
func (hc httpClient) RoundTrip(request *Request) (*Response, error) {
	hreq, err := http.NewRequest(http.MethodGet, request.URL.String(), nil)
	if err != nil {
		return nil, err
	}

	hresp, err := hc.tp.RoundTrip(hreq)
	if err != nil {
		return nil, err
	}

	return &Response{
		Status: Status(hresp.StatusCode),
		Meta:   hresp,
		Body:   hresp.Body,
	}, nil
}

// IsRedirect reports whether response is an HTTP redirect, meaning it has a
// 3xx status and carries a Location header.
//
// RoundTrip stores an *http.Response in Meta; any other Meta value is
// reported as "not a redirect" instead of panicking.
func (hc httpClient) IsRedirect(response *Response) bool {
	hresp, ok := response.Meta.(*http.Response)
	if !ok || hresp == nil {
		return false
	}
	// A Location header alone is not enough: e.g. 201 Created sets one too.
	if hresp.StatusCode < 300 || hresp.StatusCode > 399 {
		return false
	}
	return hresp.Header.Get("Location") != ""
}