summaryrefslogtreecommitdiff
path: root/client.go
diff options
context:
space:
mode:
authortjpcc <tjp@ctrl-c.club>2023-08-25 11:17:11 -0600
committertjpcc <tjp@ctrl-c.club>2023-08-25 11:17:11 -0600
commit4779fe81dd373476b06cb092dc7110e192eda5c3 (patch)
tree7eec1128f810be1bf53d3e6452c351396359c176 /client.go
parentd33ae9da7691a254e08251c375aaa9f6d58631e5 (diff)
support tls client configs in HTTP client
Diffstat (limited to 'client.go')
-rw-r--r--client.go17
1 files changed, 11 insertions, 6 deletions
diff --git a/client.go b/client.go
index 936c2d8..0a25110 100644
--- a/client.go
+++ b/client.go
@@ -35,7 +35,10 @@ var ExceededMaxRedirects = errors.New("Client: exceeded MaxRedirects")
// tlsConf may be nil, in which case gemini requests connections will not be made
// with any client certificate.
func NewClient(tlsConf *tls.Config) Client {
- hc := httpClient{}
+ hc := httpClient{tp: http.DefaultTransport.(*http.Transport).Clone()}
+ if tlsConf != nil {
+ hc.tp.TLSClientConfig = tlsConf
+ }
return Client{
protos: map[string]protocolClient{
"finger": finger.Client{},
@@ -95,12 +98,14 @@ func getRedirectLocation(proto string, meta any) string {
case "gemini", "spartan":
return meta.(string)
case "http", "https":
- return meta.(http.Header).Get("Location")
+ return meta.(*http.Request).Header.Get("Location")
}
return ""
}
-type httpClient struct{}
+type httpClient struct{
+ tp *http.Transport
+}
func (hc httpClient) RoundTrip(request *Request) (*Response, error) {
hreq, err := http.NewRequest("GET", request.URL.String(), nil)
@@ -108,18 +113,18 @@ func (hc httpClient) RoundTrip(request *Request) (*Response, error) {
return nil, err
}
- hresp, err := http.DefaultTransport.RoundTrip(hreq)
+ hresp, err := hc.tp.RoundTrip(hreq)
if err != nil {
return nil, err
}
return &Response{
Status: Status(hresp.StatusCode),
- Meta: hresp.Header,
+ Meta: hresp,
Body: hresp.Body,
}, nil
}
func (hc httpClient) IsRedirect(response *Response) bool {
- return response.Meta.(http.Header).Get("Location") != ""
+ return response.Meta.(*http.Request).Header.Get("Location") != ""
}