From 4d861a2c395c0926b066014b92999d4dda454b2b Mon Sep 17 00:00:00 2001 From: tjp Date: Sat, 13 Jan 2024 11:29:17 -0700 Subject: dial timeouts for clients, and catch up on test fixes --- gemini/client.go | 11 ++++++----- gemini/gemtext/parse_line_test.go | 2 +- gemini/gemtext/parse_test.go | 4 ++-- gemini/roundtrip_test.go | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) (limited to 'gemini') diff --git a/gemini/client.go b/gemini/client.go index 00e28f6..1e65a39 100644 --- a/gemini/client.go +++ b/gemini/client.go @@ -2,6 +2,7 @@ package gemini import ( "bytes" + "context" "crypto/tls" "errors" "io" @@ -49,7 +50,7 @@ var ExceededMaxRedirects = errors.New("gemini.Client: exceeded MaxRedirects") // // This method will not automatically follow redirects or cache permanent failures or // redirects. -func (client Client) RoundTrip(request *types.Request) (*types.Response, error) { +func (client Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) { if request.Scheme != "gemini" && request.Scheme != "titan" && request.Scheme != "" { return nil, errors.New("non-gemini protocols not supported") } @@ -64,14 +65,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error) tlsConf = &tls.Config{InsecureSkipVerify: true} } - conn, err := tls.Dial("tcp", host, tlsConf) + conn, err := (&tls.Dialer{Config: tlsConf}).DialContext(ctx, "tcp", host) if err != nil { return nil, err } defer conn.Close() request.RemoteAddr = conn.RemoteAddr() - st := conn.ConnectionState() + st := conn.(*tls.Conn).ConnectionState() request.TLSState = &st destURL := *request.URL @@ -124,14 +125,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error) // Fetch parses a URL string and fetches the gemini resource. // // It will resolve any redirects along the way, up to client.MaxRedirects. -func (c Client) Fetch(url string) (*types.Response, error) { +func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } for i := 0; i <= c.MaxRedirects; i += 1 { - response, err := c.RoundTrip(&types.Request{URL: u}) + response, err := c.RoundTrip(ctx, &types.Request{URL: u}) if err != nil { return nil, err } diff --git a/gemini/gemtext/parse_line_test.go b/gemini/gemtext/parse_line_test.go index 9073df0..a7faf49 100644 --- a/gemini/gemtext/parse_line_test.go +++ b/gemini/gemtext/parse_line_test.go @@ -92,7 +92,7 @@ func TestParsePromptLine(t *testing.T) { if line.Type() != gemtext.LineTypePrompt{ t.Errorf("expected LineTypePrompt, got %d", line.Type()) } - link, ok := line.(gemtext.PromptLine) + link, ok := line.(gemtext.LinkLine) if !ok { t.Fatalf("expected a PromptLine, got %T", line) } diff --git a/gemini/gemtext/parse_test.go b/gemini/gemtext/parse_test.go index 90d2c75..4e7b6ff 100644 --- a/gemini/gemtext/parse_test.go +++ b/gemini/gemtext/parse_test.go @@ -78,8 +78,8 @@ This is some non-blank regular text. assert.Equal(t, gemtext.LineTypePrompt, doc[12].Type()) assert.Equal(t, "=: spartan://foo.bar/baz this should be a spartan prompt\n", string(doc[12].Raw())) - assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.PromptLine).URL()) - assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.PromptLine).Label()) + assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.LinkLine).URL()) + assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.LinkLine).Label()) assertEmptyLine(t, doc[13]) diff --git a/gemini/roundtrip_test.go b/gemini/roundtrip_test.go index 50c1962..4a8097f 100644 --- a/gemini/roundtrip_test.go +++ b/gemini/roundtrip_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "tildegit.org/tjp/sliderule/internal/types" "tildegit.org/tjp/sliderule/gemini" + "tildegit.org/tjp/sliderule/internal/types" ) func TestRoundTrip(t *testing.T) { @@ -36,7 +36,7 @@ func TestRoundTrip(t *testing.T) { require.Nil(t, err) cli := gemini.NewClient(testClientTLS()) - response, err := cli.RoundTrip(&types.Request{URL: u}) + response, err := cli.RoundTrip(context.Background(), &types.Request{URL: u}) require.Nil(t, err) assert.Equal(t, gemini.StatusSuccess, response.Status) -- cgit v1.2.3