forked from mirrors/gotosocial
[bugfix] return early in websocket upgrade handler (#1315)
* launch websocket streaming in goroutine to allow upgrade handler to return * don't send any message on ping, improved close check on failed read * use context to signal wsconn close, ensure canceled in read goroutine Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
parent
98edd75f1b
commit
1bda6a2002
4 changed files with 110 additions and 70 deletions
|
@ -19,6 +19,8 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
|
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
|
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
|
||||||
|
@ -122,7 +124,7 @@ func NewClient(db db.DB, p processing.Processor) *Client {
|
||||||
notifications: notifications.New(p),
|
notifications: notifications.New(p),
|
||||||
search: search.New(p),
|
search: search.New(p),
|
||||||
statuses: statuses.New(p),
|
statuses: statuses.New(p),
|
||||||
streaming: streaming.New(p),
|
streaming: streaming.New(p, time.Second*30, 4096),
|
||||||
timelines: timelines.New(p),
|
timelines: timelines.New(p),
|
||||||
user: user.New(p),
|
user: user.New(p),
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,8 +19,9 @@
|
||||||
package streaming
|
package streaming
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"codeberg.org/gruf/go-kv"
|
"codeberg.org/gruf/go-kv"
|
||||||
|
@ -32,16 +33,6 @@ import (
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
wsUpgrader = websocket.Upgrader{
|
|
||||||
ReadBufferSize: 1024,
|
|
||||||
WriteBufferSize: 1024,
|
|
||||||
// we expect cors requests (via eg., pinafore.social) so be lenient
|
|
||||||
CheckOrigin: func(r *http.Request) bool { return true },
|
|
||||||
}
|
|
||||||
errNoToken = fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader)
|
|
||||||
)
|
|
||||||
|
|
||||||
// StreamGETHandler swagger:operation GET /api/v1/streaming streamGet
|
// StreamGETHandler swagger:operation GET /api/v1/streaming streamGet
|
||||||
//
|
//
|
||||||
// Initiate a websocket connection for live streaming of statuses and notifications.
|
// Initiate a websocket connection for live streaming of statuses and notifications.
|
||||||
|
@ -150,21 +141,20 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var accessToken string
|
var token string
|
||||||
if t := c.Query(AccessTokenQueryKey); t != "" {
|
|
||||||
// try query param first
|
// First we check for a query param provided access token
|
||||||
accessToken = t
|
if token = c.Query(AccessTokenQueryKey); token == "" {
|
||||||
} else if t := c.GetHeader(AccessTokenHeader); t != "" {
|
// Else we check the HTTP header provided token
|
||||||
// fall back to Sec-Websocket-Protocol
|
if token = c.GetHeader(AccessTokenHeader); token == "" {
|
||||||
accessToken = t
|
const errStr = "no access token provided"
|
||||||
} else {
|
err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr)
|
||||||
// no token
|
apiutil.ErrorHandler(c, err, m.processor.InstanceGet)
|
||||||
err := errNoToken
|
return
|
||||||
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken)
|
account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), token)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
|
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
|
||||||
return
|
return
|
||||||
|
@ -178,51 +168,97 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
|
||||||
|
|
||||||
l := log.WithFields(kv.Fields{
|
l := log.WithFields(kv.Fields{
|
||||||
{"account", account.Username},
|
{"account", account.Username},
|
||||||
{"path", BasePath},
|
|
||||||
{"streamID", stream.ID},
|
{"streamID", stream.ID},
|
||||||
{"streamType", streamType},
|
{"streamType", streamType},
|
||||||
}...)
|
}...)
|
||||||
|
|
||||||
wsConn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
|
// Upgrade the incoming HTTP request, which hijacks the underlying
|
||||||
|
// connection and reuses it for the websocket (non-http) protocol.
|
||||||
|
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If the upgrade fails, then Upgrade replies to the client with an HTTP error response.
|
l.Errorf("error upgrading websocket connection: %v", err)
|
||||||
// Because websocket issues are a pretty common source of headaches, we should also log
|
|
||||||
// this at Error to make this plenty visible and help admins out a bit.
|
|
||||||
l.Errorf("error upgrading websocket connection: %s", err)
|
|
||||||
close(stream.Hangup)
|
close(stream.Hangup)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
go func() {
|
||||||
// cleanup
|
// We perform the main websocket send loop in a separate
|
||||||
wsConn.Close()
|
// goroutine in order to let the upgrade handler return.
|
||||||
close(stream.Hangup)
|
// This prevents the upgrade handler from holding open any
|
||||||
}()
|
// throttle / rate-limit request tokens which could become
|
||||||
|
// problematic on instances with multiple users.
|
||||||
|
l.Info("opened websocket connection")
|
||||||
|
defer l.Info("closed websocket connection")
|
||||||
|
|
||||||
streamTicker := time.NewTicker(m.tickDuration)
|
// Create new context for lifetime of the connection
|
||||||
defer streamTicker.Stop()
|
ctx, cncl := context.WithCancel(context.Background())
|
||||||
|
|
||||||
// We want to stay in the loop as long as possible while the client is connected.
|
// Create ticker to send alive pings
|
||||||
// The only thing that should break the loop is if the client leaves or the connection becomes unhealthy.
|
pinger := time.NewTicker(m.dTicker)
|
||||||
//
|
|
||||||
// If the loop does break, we expect the client to reattempt connection, so it's cheap to leave + try again
|
defer func() {
|
||||||
wsLoop:
|
// Signal done
|
||||||
for {
|
cncl()
|
||||||
select {
|
|
||||||
case m := <-stream.Messages:
|
// Close websocket conn
|
||||||
l.Trace("received message from stream")
|
_ = wsConn.Close()
|
||||||
if err := wsConn.WriteJSON(m); err != nil {
|
|
||||||
l.Debugf("error writing json to websocket connection; breaking off: %s", err)
|
// Close processor stream
|
||||||
break wsLoop
|
close(stream.Hangup)
|
||||||
|
|
||||||
|
// Stop ping ticker
|
||||||
|
pinger.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Signal done
|
||||||
|
defer cncl()
|
||||||
|
|
||||||
|
for {
|
||||||
|
// We have to listen for received websocket messages in
|
||||||
|
// order to trigger the underlying wsConn.PingHandler().
|
||||||
|
//
|
||||||
|
// So we wait on received messages but only act on errors.
|
||||||
|
_, _, err := wsConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
// Only log error if the connection was not closed
|
||||||
|
// by us. Uncanceled context indicates this is the case.
|
||||||
|
l.Errorf("error reading from websocket: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
l.Trace("wrote message into websocket connection")
|
}()
|
||||||
case <-streamTicker.C:
|
|
||||||
l.Trace("received TICK from ticker")
|
for {
|
||||||
if err := wsConn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil {
|
select {
|
||||||
l.Debugf("error writing ping to websocket connection; breaking off: %s", err)
|
// Connection closed
|
||||||
break wsLoop
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
|
||||||
|
// Received next stream message
|
||||||
|
case msg := <-stream.Messages:
|
||||||
|
l.Tracef("sending message to websocket: %+v", msg)
|
||||||
|
if err := wsConn.WriteJSON(msg); err != nil {
|
||||||
|
l.Errorf("error writing json to websocket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset on each successful send.
|
||||||
|
pinger.Reset(m.dTicker)
|
||||||
|
|
||||||
|
// Send keep-alive "ping"
|
||||||
|
case <-pinger.C:
|
||||||
|
l.Trace("pinging websocket ...")
|
||||||
|
if err := wsConn.WriteMessage(
|
||||||
|
websocket.PingMessage,
|
||||||
|
[]byte{},
|
||||||
|
); err != nil {
|
||||||
|
l.Errorf("error writing ping to websocket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
l.Trace("wrote ping message into websocket connection")
|
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,21 +42,22 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Module struct {
|
type Module struct {
|
||||||
processor processing.Processor
|
processor processing.Processor
|
||||||
tickDuration time.Duration
|
dTicker time.Duration
|
||||||
|
wsUpgrade websocket.Upgrader
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(processor processing.Processor) *Module {
|
func New(processor processing.Processor, dTicker time.Duration, wsBuf int) *Module {
|
||||||
return &Module{
|
return &Module{
|
||||||
processor: processor,
|
processor: processor,
|
||||||
tickDuration: 30 * time.Second,
|
dTicker: dTicker,
|
||||||
}
|
wsUpgrade: websocket.Upgrader{
|
||||||
}
|
ReadBufferSize: wsBuf, // we don't expect reads
|
||||||
|
WriteBufferSize: wsBuf,
|
||||||
|
|
||||||
func NewWithTickDuration(processor processing.Processor, tickDuration time.Duration) *Module {
|
// we expect cors requests (via eg., pinafore.social) so be lenient
|
||||||
return &Module{
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
processor: processor,
|
},
|
||||||
tickDuration: tickDuration,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (suite *StreamingTestSuite) SetupTest() {
|
||||||
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
|
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
|
||||||
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
|
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
|
||||||
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
|
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
|
||||||
suite.streamingModule = streaming.NewWithTickDuration(suite.processor, 1)
|
suite.streamingModule = streaming.New(suite.processor, 1, 4096)
|
||||||
suite.NoError(suite.processor.Start())
|
suite.NoError(suite.processor.Start())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue