// GoToSocial // Copyright (C) GoToSocial Authors admin@gotosocial.org // SPDX-License-Identifier: AGPL-3.0-or-later // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . /* The code in this file is adapted from MIT-licensed code in github.com/go-chi/chi. Thanks chi (thi)! See: https://github.com/go-chi/chi/blob/e6baba61759b26ddf7b14d1e02d1da81a4d76c08/middleware/throttle.go And: https://github.com/sponsors/pkieltyka */ package middleware_test import ( "context" "net/http" "net/http/httptest" "runtime" "strconv" "testing" "time" "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/middleware" ) func TestThrottlingMiddleware(t *testing.T) { testThrottlingMiddleware(t, 2, time.Second*10) testThrottlingMiddleware(t, 4, time.Second*15) testThrottlingMiddleware(t, 8, time.Second*30) } func testThrottlingMiddleware(t *testing.T, cpuMulti int, retryAfter time.Duration) { // Calculate expected request limit + queue. limit := runtime.GOMAXPROCS(0) * cpuMulti queueLimit := limit * cpuMulti // Calculate expected retry-after header string. retryAfterStr := strconv.FormatUint(uint64(retryAfter/time.Second), 10) // Gin test http engine // (used for ctx init). e := gin.New() // Add middleware to the gin engine handler stack. middleware := middleware.Throttle(cpuMulti, retryAfter) e.Use(middleware) // Set the blocking gin handler. handler := blockingHandler() e.Handle("GET", "/", handler) var cncls []func() for i := 0; i < queueLimit+limit; i++ { // Prepare a gin test context. r := httptest.NewRequest("GET", "/", nil) rw := httptest.NewRecorder() // Wrap request with new cancel context. ctx, cncl := context.WithCancel(r.Context()) r = r.WithContext(ctx) // Pass req through // engine handler. go e.ServeHTTP(rw, r) time.Sleep(time.Millisecond) // Get http result. res := rw.Result() if i < queueLimit { // Check status == 200 (default, i.e not set). if res.StatusCode != http.StatusOK { t.Fatalf("status code was set (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i) } // Add cancel to func slice. cncls = append(cncls, cncl) } else { // Check the returned status code is expected. if res.StatusCode != http.StatusTooManyRequests { t.Fatalf("did not return status 429 (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i) } // Check the returned retry-after header is set. if res.Header.Get("Retry-After") != retryAfterStr { t.Fatalf("did not return retry-after %s with queueLimit=%d and request=%d", retryAfterStr, queueLimit, i) } // Cancel on return. defer cncl() } } // Cancel all blocked reqs. for _, cncl := range cncls { cncl() } time.Sleep(time.Second) // Check a bunchh more requests // can now make it through after // previous requests were released! for i := 0; i < limit; i++ { // Prepare a gin test context. r := httptest.NewRequest("GET", "/", nil) rw := httptest.NewRecorder() // Pass req through // engine handler. go e.ServeHTTP(rw, r) time.Sleep(5 * time.Millisecond) // Get http result. res := rw.Result() // Check status == 200 (default, i.e not set). if res.StatusCode != http.StatusOK { t.Fatalf("status code was set (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i) } } } func blockingHandler() gin.HandlerFunc { return func(ctx *gin.Context) { <-ctx.Done() ctx.Status(201) // specifically not 200 } }