From 2063d01cdb02c7ef26dc6d917e3bca252db5d5a8 Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Sun, 21 May 2023 17:59:14 +0100 Subject: [PATCH] [bugfix] Add back removed ValidateRequest() before backoff-retry loop (#1805) * add back removed ValidateRequest() before backoff-retry loop Signed-off-by: kim * include response body in error response log Signed-off-by: kim * improved error response body draining Signed-off-by: kim * add more code commenting Signed-off-by: kim * move new error response logic to gtserror, handle instead in transport.Transport{} impl Signed-off-by: kim * appease ye oh mighty linter Signed-off-by: kim * fix mockhttpclient not setting request in http response Signed-off-by: kim --------- Signed-off-by: kim --- internal/gtserror/error.go | 4 +- internal/gtserror/new.go | 66 +++++++++++++++++++++ internal/gtserror/new_test.go | 91 +++++++++++++++++++++++++++++ internal/gtserror/util.go | 42 +++++++++++++ internal/httpclient/client.go | 21 +++++-- internal/httpclient/validate.go | 62 ++++++++++++++++++++ internal/transport/deliver.go | 4 +- internal/transport/dereference.go | 4 +- internal/transport/derefinstance.go | 11 ++-- internal/transport/derefmedia.go | 4 +- internal/transport/finger.go | 21 ++++--- testrig/transportcontroller.go | 1 + 12 files changed, 299 insertions(+), 32 deletions(-) create mode 100644 internal/gtserror/new.go create mode 100644 internal/gtserror/new_test.go create mode 100644 internal/gtserror/util.go create mode 100644 internal/httpclient/validate.go diff --git a/internal/gtserror/error.go b/internal/gtserror/error.go index 56e546cf1..e68ed7d3b 100644 --- a/internal/gtserror/error.go +++ b/internal/gtserror/error.go @@ -34,8 +34,8 @@ const ( notFoundKey errorTypeKey - // error types - TypeSMTP ErrorType = "smtp" // smtp (mail) error + // Types returnable from Type(...). + TypeSMTP ErrorType = "smtp" // smtp (mail) ) // StatusCode checks error for a stored status code value. For example diff --git a/internal/gtserror/new.go b/internal/gtserror/new.go new file mode 100644 index 000000000..ad20e5cac --- /dev/null +++ b/internal/gtserror/new.go @@ -0,0 +1,66 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gtserror + +import ( + "errors" + "net/http" + + "codeberg.org/gruf/go-byteutil" +) + +// NewResponseError crafts an error from provided HTTP response +// including the method, status and body (if any provided). This +// will also wrap the returned error using WithStatusCode(). +func NewResponseError(rsp *http.Response) error { + var buf byteutil.Buffer + + // Get URL string ahead of time. + urlStr := rsp.Request.URL.String() + + // Alloc guesstimate of required buf size. + buf.Guarantee(0 + + len(rsp.Request.Method) + + 12 + // request to + len(urlStr) + + 17 + // failed: status=" + len(rsp.Status) + + 8 + // " body=" + 256 + // max body size + 1, // " + ) + + // Build error message string without + // using "fmt", as chances are this will + // be used in a hot code path and we + // know all the incoming types involved. + _, _ = buf.WriteString(rsp.Request.Method) + _, _ = buf.WriteString(" request to ") + _, _ = buf.WriteString(urlStr) + _, _ = buf.WriteString(" failed: status=\"") + _, _ = buf.WriteString(rsp.Status) + _, _ = buf.WriteString("\" body=\"") + _, _ = buf.WriteString(drainBody(rsp.Body, 256)) + _, _ = buf.WriteString("\"") + + // Create new error from msg. + err := errors.New(buf.String()) + + // Wrap error to provide status code. + return WithStatusCode(err, rsp.StatusCode) +} diff --git a/internal/gtserror/new_test.go b/internal/gtserror/new_test.go new file mode 100644 index 000000000..b0824b5a7 --- /dev/null +++ b/internal/gtserror/new_test.go @@ -0,0 +1,91 @@ +package gtserror_test + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/superseriousbusiness/gotosocial/internal/gtserror" +) + +func TestResponseError(t *testing.T) { + testResponseError(t, http.Response{ + Body: toBody(`{"error": "user not found"}`), + Request: &http.Request{ + Method: "GET", + URL: toURL("https://google.com/users/sundar"), + }, + Status: "404 Not Found", + }) + testResponseError(t, http.Response{ + Body: toBody("Unauthorized"), + Request: &http.Request{ + Method: "POST", + URL: toURL("https://google.com/inbox"), + }, + Status: "401 Unauthorized", + }) + testResponseError(t, http.Response{ + Body: toBody(""), + Request: &http.Request{ + Method: "GET", + URL: toURL("https://google.com/users/sundar"), + }, + Status: "404 Not Found", + }) +} + +func testResponseError(t *testing.T, rsp http.Response) { + var body string + if rsp.Body == http.NoBody { + body = "" + } else { + var b []byte + rsp.Body, b = copyBody(rsp.Body) + trunc := len(b) + if trunc > 256 { + trunc = 256 + } + body = string(b[:trunc]) + } + expect := fmt.Sprintf( + "%s request to %s failed: status=\"%s\" body=\"%s\"", + rsp.Request.Method, + rsp.Request.URL.String(), + rsp.Status, + body, + ) + err := gtserror.NewResponseError(&rsp) + if str := err.Error(); str != expect { + t.Errorf("unexpected error string: recv=%q expct=%q", str, expect) + } +} + +func toURL(u string) *url.URL { + url, err := url.Parse(u) + if err != nil { + panic(err) + } + return url +} + +func toBody(s string) io.ReadCloser { + if s == "" { + return http.NoBody + } + r := strings.NewReader(s) + return io.NopCloser(r) +} + +func copyBody(rc io.ReadCloser) (io.ReadCloser, []byte) { + b, err := io.ReadAll(rc) + if err != nil { + panic(err) + } + r := bytes.NewReader(b) + return io.NopCloser(r), b +} diff --git a/internal/gtserror/util.go b/internal/gtserror/util.go new file mode 100644 index 000000000..635518b76 --- /dev/null +++ b/internal/gtserror/util.go @@ -0,0 +1,42 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gtserror + +import ( + "io" + + "codeberg.org/gruf/go-byteutil" +) + +// drainBody will produce a truncated output of the content +// of given io.ReadCloser body, useful for logs / errors. +func drainBody(body io.ReadCloser, trunc int) string { + // Limit response to 'trunc' bytes. + buf := make([]byte, trunc) + + // Read body into err buffer. + n, _ := io.ReadFull(body, buf) + + if n == 0 { + // No error body, return + // reasonable error str. + return "" + } + + return byteutil.B2S(buf[:n]) +} diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index dd1a1bd6b..efbf4cd18 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -41,6 +41,9 @@ import ( ) var ( + // ErrInvalidRequest is returned if a given HTTP request is invalid and cannot be performed. + ErrInvalidRequest = errors.New("invalid http request") + // ErrInvalidNetwork is returned if the request would not be performed over TCP ErrInvalidNetwork = errors.New("invalid network type") @@ -90,6 +93,9 @@ type Config struct { // cases to protect against forged / unknown content-lengths // - protection from server side request forgery (SSRF) by only dialing // out to known public IP prefixes, configurable with allows/blocks +// - retry-backoff logic for error temporary HTTP error responses +// - optional request signing +// - request logging type Client struct { client http.Client badHosts cache.Cache[string, struct{}] @@ -156,14 +162,14 @@ func New(cfg Config) *Client { return &c } -// Do ... +// Do will essentially perform http.Client{}.Do() with retry-backoff functionality. func (c *Client) Do(r *http.Request) (*http.Response, error) { return c.DoSigned(r, func(r *http.Request) error { return nil // no request signing }) } -// DoSigned ... +// DoSigned will essentially perform http.Client{}.Do() with retry-backoff functionality and requesting signing.. func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, err error) { const ( // max no. attempts. @@ -173,6 +179,11 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e baseBackoff = 2 * time.Second ) + // First validate incoming request. + if err := ValidateRequest(r); err != nil { + return nil, err + } + // Get request hostname. host := r.URL.Hostname() @@ -234,8 +245,8 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e return rsp, nil } - // Generate error from status code for logging - err = errors.New(`http response "` + rsp.Status + `"`) + // Create loggable error from response status code. + err = fmt.Errorf(`http response: %s`, rsp.Status) // Search for a provided "Retry-After" header value. if after := rsp.Header.Get("Retry-After"); after != "" { @@ -307,7 +318,7 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e return } -// do ... +// do wraps http.Client{}.Do() to provide safely limited response bodies. func (c *Client) do(req *http.Request) (*http.Response, error) { // Perform the HTTP request. rsp, err := c.client.Do(req) diff --git a/internal/httpclient/validate.go b/internal/httpclient/validate.go new file mode 100644 index 000000000..881d3f699 --- /dev/null +++ b/internal/httpclient/validate.go @@ -0,0 +1,62 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpclient + +import ( + "fmt" + "net/http" + "strings" + + "golang.org/x/net/http/httpguts" +) + +// ValidateRequest performs the same request validation logic found in the default +// net/http.Transport{}.roundTrip() function, but pulls it out into this separate +// function allowing validation errors to be wrapped under a single error type. +func ValidateRequest(r *http.Request) error { + switch { + case r.URL == nil: + return fmt.Errorf("%w: nil url", ErrInvalidRequest) + case r.Header == nil: + return fmt.Errorf("%w: nil header", ErrInvalidRequest) + case r.URL.Host == "": + return fmt.Errorf("%w: empty url host", ErrInvalidRequest) + case r.URL.Scheme != "http" && r.URL.Scheme != "https": + return fmt.Errorf("%w: unsupported protocol %q", ErrInvalidRequest, r.URL.Scheme) + case strings.IndexFunc(r.Method, func(r rune) bool { return !httpguts.IsTokenRune(r) }) != -1: + return fmt.Errorf("%w: invalid method %q", ErrInvalidRequest, r.Method) + } + + for key, values := range r.Header { + // Check field key name is valid + if !httpguts.ValidHeaderFieldName(key) { + return fmt.Errorf("%w: invalid header field name %q", ErrInvalidRequest, key) + } + + // Check each field value is valid + for i := 0; i < len(values); i++ { + if !httpguts.ValidHeaderFieldValue(values[i]) { + return fmt.Errorf("%w: invalid header field value %q", ErrInvalidRequest, values[i]) + } + } + } + + // ps. kim wrote this + + return nil +} diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go index fff7dbcf4..054baa6a5 100644 --- a/internal/transport/deliver.go +++ b/internal/transport/deliver.go @@ -19,7 +19,6 @@ package transport import ( "context" - "fmt" "net/http" "net/url" "sync" @@ -131,8 +130,7 @@ func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error { if code := rsp.StatusCode; code != http.StatusOK && code != http.StatusCreated && code != http.StatusAccepted { - err := fmt.Errorf("POST request to %s failed: %s", url, rsp.Status) - return gtserror.WithStatusCode(err, rsp.StatusCode) + return gtserror.NewResponseError(rsp) } return nil diff --git a/internal/transport/dereference.go b/internal/transport/dereference.go index 71b10a0f1..e231e0954 100644 --- a/internal/transport/dereference.go +++ b/internal/transport/dereference.go @@ -19,7 +19,6 @@ package transport import ( "context" - "fmt" "io" "net/http" "net/url" @@ -66,8 +65,7 @@ func (t *transport) Dereference(ctx context.Context, iri *url.URL) ([]byte, erro defer rsp.Body.Close() if rsp.StatusCode != http.StatusOK { - err := fmt.Errorf("GET request to %s failed: %s", iriStr, rsp.Status) - return nil, gtserror.WithStatusCode(err, rsp.StatusCode) + return nil, gtserror.NewResponseError(rsp) } return io.ReadAll(rsp.Body) diff --git a/internal/transport/derefinstance.go b/internal/transport/derefinstance.go index 466981348..c373a140a 100644 --- a/internal/transport/derefinstance.go +++ b/internal/transport/derefinstance.go @@ -102,8 +102,7 @@ func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) - return nil, gtserror.WithStatusCode(err, resp.StatusCode) + return nil, gtserror.NewResponseError(resp) } b, err := io.ReadAll(resp.Body) @@ -133,7 +132,7 @@ func dereferenceByAPIV1Instance(ctx context.Context, t *transport, iri *url.URL) ID: ulid, Domain: iri.Host, Title: apiResp.Title, - URI: fmt.Sprintf("%s://%s", iri.Scheme, iri.Host), + URI: iri.Scheme + "://" + iri.Host, ShortDescription: apiResp.ShortDescription, Description: apiResp.Description, ContactEmail: apiResp.Email, @@ -253,8 +252,7 @@ func callNodeInfoWellKnown(ctx context.Context, t *transport, iri *url.URL) (*ur defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) - return nil, gtserror.WithStatusCode(err, resp.StatusCode) + return nil, gtserror.NewResponseError(resp) } b, err := io.ReadAll(resp.Body) @@ -305,8 +303,7 @@ func callNodeInfo(ctx context.Context, t *transport, iri *url.URL) (*apimodel.No defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - err := fmt.Errorf("GET request to %s failed: %s", iriStr, resp.Status) - return nil, gtserror.WithStatusCode(err, resp.StatusCode) + return nil, gtserror.NewResponseError(resp) } b, err := io.ReadAll(resp.Body) diff --git a/internal/transport/derefmedia.go b/internal/transport/derefmedia.go index 2d9096493..ad47d99b5 100644 --- a/internal/transport/derefmedia.go +++ b/internal/transport/derefmedia.go @@ -19,7 +19,6 @@ package transport import ( "context" - "fmt" "io" "net/http" "net/url" @@ -47,8 +46,7 @@ func (t *transport) DereferenceMedia(ctx context.Context, iri *url.URL) (io.Read // Check for an expected status code if rsp.StatusCode != http.StatusOK { - err := fmt.Errorf("GET request to %s failed: %s", iriStr, rsp.Status) - return nil, 0, gtserror.WithStatusCode(err, rsp.StatusCode) + return nil, 0, gtserror.NewResponseError(rsp) } return rsp.Body, rsp.ContentLength, nil diff --git a/internal/transport/finger.go b/internal/transport/finger.go index 18b028a64..e6086747b 100644 --- a/internal/transport/finger.go +++ b/internal/transport/finger.go @@ -27,6 +27,7 @@ import ( apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" ) // webfingerURLFor returns the URL to try a webfinger request against, as @@ -105,14 +106,16 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom // From here on out, we're handling different failure scenarios and // deciding whether we should do a host-meta based fallback or not - if (rsp.StatusCode >= 500 && rsp.StatusCode < 600) || cached { - // In case we got a 5xx, bail out irrespective of if the value - // was cached or not. The target may be broken or be signalling - // us to back-off. - // - // If it's any error but the URL was cached, bail out too - return nil, fmt.Errorf("GET request to %s failed: %s", req.URL.String(), rsp.Status) - } + // Response status codes >= 500 are returned as errors by the wrapped HTTP client. + // + // if (rsp.StatusCode >= 500 && rsp.StatusCode < 600) || cached { + // In case we got a 5xx, bail out irrespective of if the value + // was cached or not. The target may be broken or be signalling + // us to back-off. + // + // If it's any error but the URL was cached, bail out too + // return nil, gtserror.NewResponseError(rsp) + // } // So far we've failed to get a successful response from the expected // webfinger endpoint. Lets try and discover the webfinger endpoint @@ -153,7 +156,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom } // We've reached the end of the line here, both the original request // and our attempt to resolve it through the fallback have failed - return nil, fmt.Errorf("GET request to %s failed: %s", req.URL.String(), rsp.Status) + return nil, gtserror.NewResponseError(rsp) } // Set the URL in cache here, since host-meta told us this should be the diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index b74888934..1c75e1974 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -209,6 +209,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat reader := bytes.NewReader(responseBytes) readCloser := io.NopCloser(reader) return &http.Response{ + Request: req, StatusCode: responseCode, Body: readCloser, ContentLength: int64(responseContentLength),