diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index e41531a59..900df4383 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -19,6 +19,7 @@ package streaming import ( "context" + "net/http" "slices" "time" @@ -151,15 +152,24 @@ import ( // description: bad request func (m *Module) StreamGETHandler(c *gin.Context) { var ( - account *gtsmodel.Account - errWithCode gtserror.WithCode + token string + tokenInHeader bool + account *gtsmodel.Account + errWithCode gtserror.WithCode ) - // Try query param access token. - token := c.Query(AccessTokenQueryKey) - if token == "" { - // Try fallback HTTP header provided token. - token = c.GetHeader(AccessTokenHeader) + if t := c.Query(AccessTokenQueryKey); t != "" { + // Token was provided as + // query param, no problem. + token = t + } else if t := c.GetHeader(AccessTokenHeader); t != "" { + // Token was provided in "Sec-Websocket-Protocol" header. + // + // This is hacky and not technically correct but some + // clients do it since Mastodon allows it, so we must + // also allow it to avoid breaking expectations. + token = t + tokenInHeader = true } if token != "" { @@ -230,7 +240,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) { // // If the upgrade fails, then Upgrade replies to the client // with an HTTP error response. - wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) + var responseHeader http.Header + if tokenInHeader { + // Return the token in the response, + // else Chrome fails to connect. + // + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism#sec-websocket-protocol + responseHeader = http.Header{AccessTokenHeader: {token}} + } + + wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, responseHeader) if err != nil { l.Errorf("error upgrading websocket connection: %v", err) stream.Close() diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go index 1d94a87ec..acdcafd8a 100644 --- a/internal/api/client/streaming/streaming_test.go +++ b/internal/api/client/streaming/streaming_test.go @@ -22,7 +22,7 @@ import ( "encoding/base64" "errors" "fmt" - "io/ioutil" + "io" "net" "net/http" "net/http/httptest" @@ -236,7 +236,7 @@ func (suite *StreamingTestSuite) TestSecurityHeader() { result := recorder.Result() defer result.Body.Close() - b, err := ioutil.ReadAll(result.Body) + b, err := io.ReadAll(result.Body) suite.NoError(err) // check response