summaryrefslogtreecommitdiff
path: root/gemini/client.go
diff options
context:
space:
mode:
authortjp <tjp@ctrl-c.club>2024-01-13 11:29:17 -0700
committertjp <tjp@ctrl-c.club>2024-01-13 11:29:17 -0700
commit4d861a2c395c0926b066014b92999d4dda454b2b (patch)
tree7bbf7d94cec2a5412a68b830c1e5223918d9bc82 /gemini/client.go
parentde1490808fa6e4d6749ff29d20cc1a589ec476d1 (diff)
dial timeouts for clients, and catch up on test fixes
Diffstat (limited to 'gemini/client.go')
-rw-r--r--gemini/client.go11
1 files changed, 6 insertions, 5 deletions
diff --git a/gemini/client.go b/gemini/client.go
index 00e28f6..1e65a39 100644
--- a/gemini/client.go
+++ b/gemini/client.go
@@ -2,6 +2,7 @@ package gemini
import (
"bytes"
+ "context"
"crypto/tls"
"errors"
"io"
@@ -49,7 +50,7 @@ var ExceededMaxRedirects = errors.New("gemini.Client: exceeded MaxRedirects")
//
// This method will not automatically follow redirects or cache permanent failures or
// redirects.
-func (client Client) RoundTrip(request *types.Request) (*types.Response, error) {
+func (client Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "gemini" && request.Scheme != "titan" && request.Scheme != "" {
return nil, errors.New("non-gemini protocols not supported")
}
@@ -64,14 +65,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error)
tlsConf = &tls.Config{InsecureSkipVerify: true}
}
- conn, err := tls.Dial("tcp", host, tlsConf)
+ conn, err := (&tls.Dialer{Config: tlsConf}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
defer conn.Close()
request.RemoteAddr = conn.RemoteAddr()
- st := conn.ConnectionState()
+ st := conn.(*tls.Conn).ConnectionState()
request.TLSState = &st
destURL := *request.URL
@@ -124,14 +125,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error)
// Fetch parses a URL string and fetches the gemini resource.
//
// It will resolve any redirects along the way, up to client.MaxRedirects.
-func (c Client) Fetch(url string) (*types.Response, error) {
+func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
for i := 0; i <= c.MaxRedirects; i += 1 {
- response, err := c.RoundTrip(&types.Request{URL: u})
+ response, err := c.RoundTrip(ctx, &types.Request{URL: u})
if err != nil {
return nil, err
}