From 4d861a2c395c0926b066014b92999d4dda454b2b Mon Sep 17 00:00:00 2001 From: tjp Date: Sat, 13 Jan 2024 11:29:17 -0700 Subject: dial timeouts for clients, and catch up on test fixes --- gemini/client.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'gemini/client.go') diff --git a/gemini/client.go b/gemini/client.go index 00e28f6..1e65a39 100644 --- a/gemini/client.go +++ b/gemini/client.go @@ -2,6 +2,7 @@ package gemini import ( "bytes" + "context" "crypto/tls" "errors" "io" @@ -49,7 +50,7 @@ var ExceededMaxRedirects = errors.New("gemini.Client: exceeded MaxRedirects") // // This method will not automatically follow redirects or cache permanent failures or // redirects. -func (client Client) RoundTrip(request *types.Request) (*types.Response, error) { +func (client Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) { if request.Scheme != "gemini" && request.Scheme != "titan" && request.Scheme != "" { return nil, errors.New("non-gemini protocols not supported") } @@ -64,14 +65,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error) tlsConf = &tls.Config{InsecureSkipVerify: true} } - conn, err := tls.Dial("tcp", host, tlsConf) + conn, err := (&tls.Dialer{Config: tlsConf}).DialContext(ctx, "tcp", host) if err != nil { return nil, err } defer conn.Close() request.RemoteAddr = conn.RemoteAddr() - st := conn.ConnectionState() + st := conn.(*tls.Conn).ConnectionState() request.TLSState = &st destURL := *request.URL @@ -124,14 +125,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error) // Fetch parses a URL string and fetches the gemini resource. // // It will resolve any redirects along the way, up to client.MaxRedirects. -func (c Client) Fetch(url string) (*types.Response, error) { +func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } for i := 0; i <= c.MaxRedirects; i += 1 { - response, err := c.RoundTrip(&types.Request{URL: u}) + response, err := c.RoundTrip(ctx, &types.Request{URL: u}) if err != nil { return nil, err } -- cgit v1.2.3