diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index a4f9f2044..152ae33d7 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -115,6 +115,8 @@ func (c *Caches) Init() {
c.initUserMute()
c.initUserMuteIDs()
c.initWebfinger()
+ c.initWebPushSubscription()
+ c.initWebPushSubscriptionIDs()
c.initVisibility()
c.initStatusesFilterableFields()
}
diff --git a/internal/cache/db.go b/internal/cache/db.go
index aac11236a..c264d5567 100644
--- a/internal/cache/db.go
+++ b/internal/cache/db.go
@@ -252,6 +252,15 @@ type DBCaches struct {
// UserMuteIDs provides access to the user mute IDs database cache.
UserMuteIDs SliceCache[string]
+
+ // VAPIDKeyPair caches the server's VAPID key pair.
+ VAPIDKeyPair atomic.Pointer[gtsmodel.VAPIDKeyPair]
+
+ // WebPushSubscription provides access to the gtsmodel WebPushSubscription database cache.
+ WebPushSubscription StructCache[*gtsmodel.WebPushSubscription]
+
+ // WebPushSubscriptionIDs provides access to the Web Push subscription IDs database cache.
+ WebPushSubscriptionIDs SliceCache[string]
}
// NOTE:
@@ -1509,9 +1518,10 @@ func (c *Caches) initToken() {
{Fields: "Refresh"},
{Fields: "ClientID", Multiple: true},
},
- MaxSize: cap,
- IgnoreErr: ignoreErrors,
- Copy: copyF,
+ MaxSize: cap,
+ IgnoreErr: ignoreErrors,
+ Copy: copyF,
+ Invalidate: c.OnInvalidateToken,
})
}
@@ -1621,3 +1631,40 @@ func (c *Caches) initUserMuteIDs() {
c.DB.UserMuteIDs.Init(0, cap)
}
+
+func (c *Caches) initWebPushSubscription() {
+ cap := calculateResultCacheMax(
+ sizeofWebPushSubscription(), // model in-mem size.
+ config.GetCacheWebPushSubscriptionMemRatio(),
+ )
+
+ log.Infof(nil, "cache size = %d", cap)
+
+ copyF := func(s1 *gtsmodel.WebPushSubscription) *gtsmodel.WebPushSubscription {
+ s2 := new(gtsmodel.WebPushSubscription)
+ *s2 = *s1
+ return s2
+ }
+
+ c.DB.WebPushSubscription.Init(structr.CacheConfig[*gtsmodel.WebPushSubscription]{
+ Indices: []structr.IndexConfig{
+ {Fields: "ID"},
+ {Fields: "TokenID"},
+ {Fields: "AccountID", Multiple: true},
+ },
+ MaxSize: cap,
+ IgnoreErr: ignoreErrors,
+ Invalidate: c.OnInvalidateWebPushSubscription,
+ Copy: copyF,
+ })
+}
+
+func (c *Caches) initWebPushSubscriptionIDs() {
+ cap := calculateSliceCacheMax(
+ config.GetCacheWebPushSubscriptionIDsMemRatio(),
+ )
+
+ log.Infof(nil, "cache size = %d", cap)
+
+ c.DB.WebPushSubscriptionIDs.Init(0, cap)
+}
diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go
index 9b42e88f6..be3eaa735 100644
--- a/internal/cache/invalidate.go
+++ b/internal/cache/invalidate.go
@@ -278,6 +278,11 @@ func (c *Caches) OnInvalidateStatusFave(fave *gtsmodel.StatusFave) {
c.DB.StatusFaveIDs.Invalidate(fave.StatusID)
}
+func (c *Caches) OnInvalidateToken(token *gtsmodel.Token) {
+ // Invalidate token's push subscription.
+ c.DB.WebPushSubscription.Invalidate("ID", token.ID)
+}
+
func (c *Caches) OnInvalidateUser(user *gtsmodel.User) {
// Invalidate local account ID cached visibility.
c.Visibility.Invalidate("ItemID", user.AccountID)
@@ -291,3 +296,8 @@ func (c *Caches) OnInvalidateUserMute(mute *gtsmodel.UserMute) {
// Invalidate source account's user mute lists.
c.DB.UserMuteIDs.Invalidate(mute.AccountID)
}
+
+func (c *Caches) OnInvalidateWebPushSubscription(subscription *gtsmodel.WebPushSubscription) {
+ // Invalidate source account's Web Push subscription list.
+ c.DB.WebPushSubscriptionIDs.Invalidate(subscription.AccountID)
+}
diff --git a/internal/cache/size.go b/internal/cache/size.go
index 26f4096ed..abed1e3b6 100644
--- a/internal/cache/size.go
+++ b/internal/cache/size.go
@@ -66,6 +66,14 @@ you'll make society more equitable for all if you're not careful! :hammer_sickle
// be a serialized string of almost any type, so we pick a
// nice serialized key size on the upper end of normal.
sizeofResultKey = 2 * sizeofIDStr
+
+ // exampleWebPushAuth is a Base64-encoded 16-byte random auth secret.
+ // This secret is consumed as Base64 by webpush-go.
+ exampleWebPushAuth = "ZVxqlt5fzVgmSz2aqiA2XQ=="
+
+ // exampleWebPushP256dh is a Base64-encoded DH P-256 public key.
+ // This secret is consumed as Base64 by webpush-go.
+ exampleWebPushP256dh = "OrpejO16gV97uBXew/T0I7YoUv/CX8fz0z4g8RrQ+edXJqQPjX3XVSo2P0HhcCpCOR1+Dzj5LFcK9jYNqX7SBg=="
)
var (
@@ -558,7 +566,7 @@ func sizeofMove() uintptr {
func sizeofNotification() uintptr {
return uintptr(size.Of(>smodel.Notification{
ID: exampleID,
- NotificationType: gtsmodel.NotificationFave,
+ NotificationType: gtsmodel.NotificationFavourite,
CreatedAt: exampleTime,
TargetAccountID: exampleID,
OriginAccountID: exampleID,
@@ -786,3 +794,11 @@ func sizeofUserMute() uintptr {
Notifications: util.Ptr(false),
}))
}
+
+func sizeofWebPushSubscription() uintptr {
+ return uintptr(size.Of(>smodel.WebPushSubscription{
+ TokenID: exampleID,
+ Auth: exampleWebPushAuth,
+ P256dh: exampleWebPushP256dh,
+ }))
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 2e3ad8ec1..d9491740e 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -248,6 +248,8 @@ type CacheConfiguration struct {
UserMuteMemRatio float64 `name:"user-mute-mem-ratio"`
UserMuteIDsMemRatio float64 `name:"user-mute-ids-mem-ratio"`
WebfingerMemRatio float64 `name:"webfinger-mem-ratio"`
+ WebPushSubscriptionMemRatio float64 `name:"web-push-subscription-mem-ratio"`
+ WebPushSubscriptionIDsMemRatio float64 `name:"web-push-subscription-ids-mem-ratio"`
VisibilityMemRatio float64 `name:"visibility-mem-ratio"`
}
diff --git a/internal/config/defaults.go b/internal/config/defaults.go
index 9b45002d0..0b28b9025 100644
--- a/internal/config/defaults.go
+++ b/internal/config/defaults.go
@@ -209,6 +209,8 @@ var Defaults = Configuration{
UserMuteMemRatio: 2,
UserMuteIDsMemRatio: 3,
WebfingerMemRatio: 0.1,
+ WebPushSubscriptionMemRatio: 1,
+ WebPushSubscriptionIDsMemRatio: 1,
VisibilityMemRatio: 2,
},
diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go
index a35622f8e..2c554d87a 100644
--- a/internal/config/helpers.gen.go
+++ b/internal/config/helpers.gen.go
@@ -4162,6 +4162,64 @@ func GetCacheWebfingerMemRatio() float64 { return global.GetCacheWebfingerMemRat
// SetCacheWebfingerMemRatio safely sets the value for global configuration 'Cache.WebfingerMemRatio' field
func SetCacheWebfingerMemRatio(v float64) { global.SetCacheWebfingerMemRatio(v) }
+// GetCacheWebPushSubscriptionMemRatio safely fetches the Configuration value for state's 'Cache.WebPushSubscriptionMemRatio' field
+func (st *ConfigState) GetCacheWebPushSubscriptionMemRatio() (v float64) {
+ st.mutex.RLock()
+ v = st.config.Cache.WebPushSubscriptionMemRatio
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheWebPushSubscriptionMemRatio safely sets the Configuration value for state's 'Cache.WebPushSubscriptionMemRatio' field
+func (st *ConfigState) SetCacheWebPushSubscriptionMemRatio(v float64) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.WebPushSubscriptionMemRatio = v
+ st.reloadToViper()
+}
+
+// CacheWebPushSubscriptionMemRatioFlag returns the flag name for the 'Cache.WebPushSubscriptionMemRatio' field
+func CacheWebPushSubscriptionMemRatioFlag() string { return "cache-web-push-subscription-mem-ratio" }
+
+// GetCacheWebPushSubscriptionMemRatio safely fetches the value for global configuration 'Cache.WebPushSubscriptionMemRatio' field
+func GetCacheWebPushSubscriptionMemRatio() float64 {
+ return global.GetCacheWebPushSubscriptionMemRatio()
+}
+
+// SetCacheWebPushSubscriptionMemRatio safely sets the value for global configuration 'Cache.WebPushSubscriptionMemRatio' field
+func SetCacheWebPushSubscriptionMemRatio(v float64) { global.SetCacheWebPushSubscriptionMemRatio(v) }
+
+// GetCacheWebPushSubscriptionIDsMemRatio safely fetches the Configuration value for state's 'Cache.WebPushSubscriptionIDsMemRatio' field
+func (st *ConfigState) GetCacheWebPushSubscriptionIDsMemRatio() (v float64) {
+ st.mutex.RLock()
+ v = st.config.Cache.WebPushSubscriptionIDsMemRatio
+ st.mutex.RUnlock()
+ return
+}
+
+// SetCacheWebPushSubscriptionIDsMemRatio safely sets the Configuration value for state's 'Cache.WebPushSubscriptionIDsMemRatio' field
+func (st *ConfigState) SetCacheWebPushSubscriptionIDsMemRatio(v float64) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.WebPushSubscriptionIDsMemRatio = v
+ st.reloadToViper()
+}
+
+// CacheWebPushSubscriptionIDsMemRatioFlag returns the flag name for the 'Cache.WebPushSubscriptionIDsMemRatio' field
+func CacheWebPushSubscriptionIDsMemRatioFlag() string {
+ return "cache-web-push-subscription-ids-mem-ratio"
+}
+
+// GetCacheWebPushSubscriptionIDsMemRatio safely fetches the value for global configuration 'Cache.WebPushSubscriptionIDsMemRatio' field
+func GetCacheWebPushSubscriptionIDsMemRatio() float64 {
+ return global.GetCacheWebPushSubscriptionIDsMemRatio()
+}
+
+// SetCacheWebPushSubscriptionIDsMemRatio safely sets the value for global configuration 'Cache.WebPushSubscriptionIDsMemRatio' field
+func SetCacheWebPushSubscriptionIDsMemRatio(v float64) {
+ global.SetCacheWebPushSubscriptionIDsMemRatio(v)
+}
+
// GetCacheVisibilityMemRatio safely fetches the Configuration value for state's 'Cache.VisibilityMemRatio' field
func (st *ConfigState) GetCacheVisibilityMemRatio() (v float64) {
st.mutex.RLock()
diff --git a/internal/db/admin.go b/internal/db/admin.go
index 77fbbe613..1f24c7932 100644
--- a/internal/db/admin.go
+++ b/internal/db/admin.go
@@ -68,14 +68,6 @@ type Admin interface {
// the number of pending sign-ups sitting in the backlog.
CountUnhandledSignups(ctx context.Context) (int, error)
- // GetVAPIDKeyPair retrieves the existing VAPID key pair, if there is one.
- // If there isn't, it returns nil.
- GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error)
-
- // PutVAPIDKeyPair stores a VAPID key pair.
- // This should be called at most once, during server startup.
- PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error
-
/*
ACTION FUNCS
*/
diff --git a/internal/db/application.go b/internal/db/application.go
index b71e593c2..5a4068431 100644
--- a/internal/db/application.go
+++ b/internal/db/application.go
@@ -48,6 +48,9 @@ type Application interface {
// GetAllTokens ...
GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error)
+ // GetTokenByID ...
+ GetTokenByID(ctx context.Context, id string) (*gtsmodel.Token, error)
+
// GetTokenByCode ...
GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error)
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
index 266b351f5..ff398fca5 100644
--- a/internal/db/bundb/admin.go
+++ b/internal/db/bundb/admin.go
@@ -48,9 +48,6 @@ const rsaKeyBits = 2048
type adminDB struct {
db *bun.DB
state *state.State
-
- // Since the VAPID key pair is very small and never written to concurrently, we can cache it here.
- vapidKeyPair *gtsmodel.VAPIDKeyPair
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, error) {
@@ -445,39 +442,6 @@ func (a *adminDB) CountUnhandledSignups(ctx context.Context) (int, error) {
Count(ctx)
}
-func (a *adminDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) {
- // Look for cached keys.
- if a.vapidKeyPair != nil {
- return a.vapidKeyPair, nil
- }
-
- // Look for previously generated keys in the database.
- if err := a.db.NewSelect().
- Model(a.vapidKeyPair).
- Limit(1).
- Scan(ctx); // nocollapse
- err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, gtserror.Newf("DB error getting VAPID key pair: %w", err)
- }
-
- return a.vapidKeyPair, nil
-}
-
-func (a *adminDB) PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error {
- // Store the keys in the database.
- if _, err := a.db.NewInsert().
- Model(a.vapidKeyPair).
- Exec(ctx); // nocollapse
- err != nil {
- return gtserror.Newf("DB error putting VAPID key pair: %w", err)
- }
-
- // Cache the keys.
- a.vapidKeyPair = vapidKeyPair
-
- return nil
-}
-
/*
ACTION FUNCS
*/
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go
index cbba499b0..92fc5ea2b 100644
--- a/internal/db/bundb/application.go
+++ b/internal/db/bundb/application.go
@@ -174,6 +174,16 @@ func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, er
return tokens, nil
}
+func (a *applicationDB) GetTokenByID(ctx context.Context, code string) (*gtsmodel.Token, error) {
+ return a.getTokenBy(
+ "ID",
+ func(t *gtsmodel.Token) error {
+ return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("id"), code).Scan(ctx)
+ },
+ code,
+ )
+}
+
func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) {
return a.getTokenBy(
"Code",
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 70132fe58..c307e0356 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -87,6 +87,7 @@ type DBService struct {
db.Timeline
db.User
db.Tombstone
+ db.WebPush
db.WorkerTask
db *bun.DB
}
@@ -296,6 +297,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ WebPush: &webPushDB{
+ db: db,
+ state: state,
+ },
WorkerTask: &workerTaskDB{
db: db,
},
diff --git a/internal/db/bundb/notification_test.go b/internal/db/bundb/notification_test.go
index 8e2fb8031..8cc778071 100644
--- a/internal/db/bundb/notification_test.go
+++ b/internal/db/bundb/notification_test.go
@@ -66,7 +66,7 @@ func (suite *NotificationTestSuite) spamNotifs() {
notif := >smodel.Notification{
ID: notifID,
- NotificationType: gtsmodel.NotificationFave,
+ NotificationType: gtsmodel.NotificationFavourite,
CreatedAt: time.Now(),
TargetAccountID: targetAccountID,
OriginAccountID: originAccountID,
diff --git a/internal/db/bundb/webpush.go b/internal/db/bundb/webpush.go
new file mode 100644
index 000000000..bb2ee2ba2
--- /dev/null
+++ b/internal/db/bundb/webpush.go
@@ -0,0 +1,203 @@
+package bundb
+
+import (
+ "context"
+ "errors"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util/xslices"
+ "github.com/uptrace/bun"
+)
+
+type webPushDB struct {
+ db *bun.DB
+ state *state.State
+}
+
+func (w *webPushDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) {
+ // Look for cached keys.
+ vapidKeyPair := w.state.Caches.DB.VAPIDKeyPair.Load()
+ if vapidKeyPair != nil {
+ return vapidKeyPair, nil
+ }
+
+ // Look for previously generated keys in the database.
+ if err := w.db.NewSelect().
+ Model(vapidKeyPair).
+ Limit(1).
+ Scan(ctx); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, err
+ }
+
+ // Cache the keys.
+ w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair)
+
+ return vapidKeyPair, nil
+}
+
+func (w *webPushDB) PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error {
+ // Store the keys in the database.
+ if _, err := w.db.NewInsert().
+ Model(vapidKeyPair).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return err
+ }
+
+ // Cache the keys.
+ w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair)
+
+ return nil
+}
+
+func (w *webPushDB) GetWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) (*gtsmodel.WebPushSubscription, error) {
+ return w.state.Caches.DB.WebPushSubscription.LoadOne(
+ "TokenID",
+ func() (*gtsmodel.WebPushSubscription, error) {
+ var subscription gtsmodel.WebPushSubscription
+ err := w.db.
+ NewSelect().
+ Model(&subscription).
+ Where("? = ?", bun.Ident("token_id"), tokenID).
+ Scan(ctx)
+ return &subscription, err
+ },
+ tokenID,
+ )
+}
+
+func (w *webPushDB) PutWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription) error {
+ return w.state.Caches.DB.WebPushSubscription.Store(subscription, func() error {
+ _, err := w.db.NewInsert().
+ Model(subscription).
+ Exec(ctx)
+ return err
+ })
+}
+
+func (w *webPushDB) UpdateWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription, columns ...string) error {
+ // If we're updating by column, ensure "updated_at" is included.
+ if len(columns) > 0 {
+ columns = append(columns, "updated_at")
+ }
+
+ // Update database.
+ if _, err := w.db.
+ NewUpdate().
+ Model(subscription).
+ Column(columns...).
+ Where("? = ?", bun.Ident("id"), subscription.ID).
+ Exec(ctx); // nocollapse
+ err != nil {
+ return err
+ }
+
+ // Update cache.
+ w.state.Caches.DB.WebPushSubscription.Put(subscription)
+
+ return nil
+}
+
+func (w *webPushDB) DeleteWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) error {
+ // Deleted partial model for cache invalidation.
+ var deleted gtsmodel.WebPushSubscription
+
+ // Delete subscription, returning subset of columns used by invalidation hook.
+ if _, err := w.db.NewDelete().
+ Model(&deleted).
+ Where("? = ?", bun.Ident("token_id"), tokenID).
+ Returning("?", bun.Ident("account_id")).
+ Exec(ctx); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
+
+ // Invalidate cached subscription by token ID.
+ w.state.Caches.DB.WebPushSubscription.Invalidate("TokenID", tokenID)
+
+ // Call invalidate hook directly.
+ w.state.Caches.OnInvalidateWebPushSubscription(&deleted)
+
+ return nil
+}
+
+func (w *webPushDB) GetWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) ([]*gtsmodel.WebPushSubscription, error) {
+ // Fetch IDs of all subscriptions created by this account.
+ subscriptionIDs, err := loadPagedIDs(&w.state.Caches.DB.WebPushSubscriptionIDs, accountID, nil, func() ([]string, error) {
+ // Subscription IDs not in cache. Perform DB query.
+ var subscriptionIDs []string
+ if _, err := w.db.
+ NewSelect().
+ Model((*gtsmodel.WebPushSubscription)(nil)).
+ Column("id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Order("id DESC").
+ Exec(ctx, &subscriptionIDs); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, err
+ }
+ return subscriptionIDs, nil
+ })
+ if len(subscriptionIDs) == 0 {
+ return nil, nil
+ }
+
+ // Get each subscription by ID from the cache or DB.
+ subscriptions, err := w.state.Caches.DB.WebPushSubscription.LoadIDs("ID",
+ subscriptionIDs,
+ func(uncached []string) ([]*gtsmodel.WebPushSubscription, error) {
+ subscriptions := make([]*gtsmodel.WebPushSubscription, 0, len(uncached))
+ if err := w.db.
+ NewSelect().
+ Model(&subscriptions).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); // nocollapse
+ err != nil {
+ return nil, err
+ }
+ return subscriptions, nil
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Put the subscription structs in the same order as the filter IDs.
+ xslices.OrderBy(
+ subscriptions,
+ subscriptionIDs,
+ func(subscription *gtsmodel.WebPushSubscription) string {
+ return subscription.ID
+ },
+ )
+
+ return subscriptions, nil
+}
+
+func (w *webPushDB) DeleteWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) error {
+ // Deleted partial models for cache invalidation.
+ var deleted []*gtsmodel.WebPushSubscription
+
+ // Delete subscriptions, returning subset of columns.
+ if _, err := w.db.NewDelete().
+ Model(&deleted).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Returning("?", bun.Ident("account_id")).
+ Exec(ctx); // nocollapse
+ err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
+
+ // Invalidate cached subscriptions by account ID.
+ w.state.Caches.DB.WebPushSubscription.Invalidate("AccountID", accountID)
+
+ // Call invalidate hooks directly in case those entries weren't cached.
+ for _, subscription := range deleted {
+ w.state.Caches.OnInvalidateWebPushSubscription(subscription)
+ }
+
+ return nil
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index c42985912..b7e2b29bd 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -57,5 +57,6 @@ type DB interface {
Timeline
User
Tombstone
+ WebPush
WorkerTask
}
diff --git a/internal/db/webpush.go b/internal/db/webpush.go
new file mode 100644
index 000000000..6752657d7
--- /dev/null
+++ b/internal/db/webpush.go
@@ -0,0 +1,53 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// WebPush contains functions related to Web Push notifications.
+type WebPush interface {
+ // GetVAPIDKeyPair retrieves the server's existing VAPID key pair, if there is one.
+ // If there isn't, it returns nil.
+ GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error)
+
+ // PutVAPIDKeyPair stores the server's VAPID key pair.
+ // This should be called at most once, during server startup.
+ PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error
+
+ // GetWebPushSubscriptionByTokenID retrieves an access token's Web Push subscription, if there is one.
+ GetWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) (*gtsmodel.WebPushSubscription, error)
+
+ // PutWebPushSubscription creates an access token's Web Push subscription.
+ PutWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription) error
+
+ // UpdateWebPushSubscription updates an access token's Web Push subscription.
+ UpdateWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription, columns ...string) error
+
+ // DeleteWebPushSubscriptionByTokenID deletes an access token's Web Push subscription, if there is one.
+ DeleteWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) error
+
+ // GetWebPushSubscriptionsByAccountID retrieves an account's list of Web Push subscriptions.
+ GetWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) ([]*gtsmodel.WebPushSubscription, error)
+
+ // DeleteWebPushSubscriptionsByAccountID deletes an account's list of Web Push subscriptions.
+ DeleteWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) error
+}
diff --git a/internal/gtsmodel/notification.go b/internal/gtsmodel/notification.go
index 1ef805081..bdaa3f563 100644
--- a/internal/gtsmodel/notification.go
+++ b/internal/gtsmodel/notification.go
@@ -48,13 +48,14 @@ const (
NotificationFollowRequest NotificationType = 2 // NotificationFollowRequest -- someone requested to follow you
NotificationMention NotificationType = 3 // NotificationMention -- someone mentioned you in their status
NotificationReblog NotificationType = 4 // NotificationReblog -- someone boosted one of your statuses
- NotificationFave NotificationType = 5 // NotificationFave -- someone faved/liked one of your statuses
+ NotificationFavourite NotificationType = 5 // NotificationFavourite -- someone faved/liked one of your statuses
NotificationPoll NotificationType = 6 // NotificationPoll -- a poll you voted in or created has ended
NotificationStatus NotificationType = 7 // NotificationStatus -- someone you enabled notifications for has posted a status.
- NotificationSignup NotificationType = 8 // NotificationSignup -- someone has submitted a new account sign-up to the instance.
- NotificationPendingFave NotificationType = 9 // Someone has faved a status of yours, which requires approval by you.
- NotificationPendingReply NotificationType = 10 // Someone has replied to a status of yours, which requires approval by you.
- NotificationPendingReblog NotificationType = 11 // Someone has boosted a status of yours, which requires approval by you.
+ NotificationAdminSignup NotificationType = 8 // NotificationAdminSignup -- someone has submitted a new account sign-up to the instance.
+ NotificationPendingFave NotificationType = 9 // NotificationPendingFave -- Someone has faved a status of yours, which requires approval by you.
+ NotificationPendingReply NotificationType = 10 // NotificationPendingReply -- Someone has replied to a status of yours, which requires approval by you.
+ NotificationPendingReblog NotificationType = 11 // NotificationPendingReblog -- Someone has boosted a status of yours, which requires approval by you.
+ NotificationAdminReport NotificationType = 12 // NotificationAdminReport -- someone has submitted a new report to the instance.
)
// String returns a stringified, frontend API compatible form of NotificationType.
@@ -68,13 +69,13 @@ func (t NotificationType) String() string {
return "mention"
case NotificationReblog:
return "reblog"
- case NotificationFave:
+ case NotificationFavourite:
return "favourite"
case NotificationPoll:
return "poll"
case NotificationStatus:
return "status"
- case NotificationSignup:
+ case NotificationAdminSignup:
return "admin.sign_up"
case NotificationPendingFave:
return "pending.favourite"
@@ -82,6 +83,8 @@ func (t NotificationType) String() string {
return "pending.reply"
case NotificationPendingReblog:
return "pending.reblog"
+ case NotificationAdminReport:
+ return "admin.report"
default:
panic("invalid notification type")
}
@@ -99,19 +102,21 @@ func ParseNotificationType(in string) NotificationType {
case "reblog":
return NotificationReblog
case "favourite":
- return NotificationFave
+ return NotificationFavourite
case "poll":
return NotificationPoll
case "status":
return NotificationStatus
case "admin.sign_up":
- return NotificationSignup
+ return NotificationAdminSignup
case "pending.favourite":
return NotificationPendingFave
case "pending.reply":
return NotificationPendingReply
case "pending.reblog":
return NotificationPendingReblog
+ case "admin.report":
+ return NotificationAdminReport
default:
return NotificationUnknown
}
diff --git a/internal/gtsmodel/vapidkeypair.go b/internal/gtsmodel/vapidkeypair.go
index 85883df45..56b7edda8 100644
--- a/internal/gtsmodel/vapidkeypair.go
+++ b/internal/gtsmodel/vapidkeypair.go
@@ -22,7 +22,7 @@ package gtsmodel
//
// See: https://datatracker.ietf.org/doc/html/rfc8292
type VAPIDKeyPair struct {
- ID int `bun:"pk,notnull"`
- Public string `bun:"notnull,nullzero"`
- Private string `bun:"notnull,nullzero"`
+ ID int `bun:",pk,notnull"`
+ Public string `bun:",notnull,nullzero"`
+ Private string `bun:",notnull,nullzero"`
}
diff --git a/internal/gtsmodel/webpushsubscription.go b/internal/gtsmodel/webpushsubscription.go
new file mode 100644
index 000000000..b14fb1caf
--- /dev/null
+++ b/internal/gtsmodel/webpushsubscription.go
@@ -0,0 +1,67 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package gtsmodel
+
+import (
+ "time"
+)
+
+// WebPushSubscription represents an access token's Web Push subscription.
+// There can be at most one per access token.
+type WebPushSubscription struct {
+ // ID of this subscription in the database.
+ ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"`
+
+ // CreatedAt is the time this subscription was created.
+ CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"`
+
+ // UpdatedAt is the time this subscription was last updated.
+ UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"`
+
+ // AccountID of the local account that created this subscription.
+ AccountID string `bun:"type:CHAR(26),notnull,nullzero"`
+
+ // TokenID is the ID of the associated access token.
+ // There can be at most one subscription for any given access token,
+ TokenID string `bun:"type:CHAR(26),nullzero,notnull,unique"`
+
+ // Endpoint is the URL receiving Web Push notifications for this subscription.
+ Endpoint string `bun:",nullzero,notnull"`
+
+ // Auth is a Base64-encoded authentication secret.
+ Auth string `bun:",nullzero,notnull"`
+
+ // P256dh is a Base64-encoded Diffie-Hellman public key on the P-256 elliptic curve.
+ P256dh string `bun:",nullzero,notnull"`
+
+ // NotifyFollow and friends control which notifications are delivered to a given subscription.
+ // Corresponds to NotificationType and model.PushSubscriptionAlerts.
+ NotifyFollow *bool `bun:",nullzero,notnull,default:false"`
+ NotifyFollowRequest *bool `bun:",nullzero,notnull,default:false"`
+ NotifyFavourite *bool `bun:",nullzero,notnull,default:false"`
+ NotifyMention *bool `bun:",nullzero,notnull,default:false"`
+ NotifyReblog *bool `bun:",nullzero,notnull,default:false"`
+ NotifyPoll *bool `bun:",nullzero,notnull,default:false"`
+ NotifyStatus *bool `bun:",nullzero,notnull,default:false"`
+ NotifyUpdate *bool `bun:",nullzero,notnull,default:false"`
+ NotifyAdminSignup *bool `bun:",nullzero,notnull,default:false"`
+ NotifyAdminReport *bool `bun:",nullzero,notnull,default:false"`
+ NotifyPendingFave *bool `bun:",nullzero,notnull,default:false"`
+ NotifyPendingReply *bool `bun:",nullzero,notnull,default:false"`
+ NotifyPendingReblog *bool `bun:",nullzero,notnull,default:false"`
+}
diff --git a/internal/processing/timeline/notification.go b/internal/processing/timeline/notification.go
index a242c7b74..09636e7eb 100644
--- a/internal/processing/timeline/notification.go
+++ b/internal/processing/timeline/notification.go
@@ -184,7 +184,7 @@ func (p *Processor) notifVisible(
// If this is a new local account sign-up,
// skip normal visibility checking because
// origin account won't be confirmed yet.
- if n.NotificationType == gtsmodel.NotificationSignup {
+ if n.NotificationType == gtsmodel.NotificationAdminSignup {
return true, nil
}
diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go
index d7d7454e7..ca6fe38f9 100644
--- a/internal/processing/workers/fromfediapi_test.go
+++ b/internal/processing/workers/fromfediapi_test.go
@@ -241,7 +241,7 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {
notif := >smodel.Notification{}
err = testStructs.State.DB.GetWhere(context.Background(), where, notif)
suite.NoError(err)
- suite.Equal(gtsmodel.NotificationFave, notif.NotificationType)
+ suite.Equal(gtsmodel.NotificationFavourite, notif.NotificationType)
suite.Equal(fave.TargetAccountID, notif.TargetAccountID)
suite.Equal(fave.AccountID, notif.OriginAccountID)
suite.Equal(fave.StatusID, notif.StatusID)
@@ -314,7 +314,7 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(
notif := >smodel.Notification{}
err = testStructs.State.DB.GetWhere(context.Background(), where, notif)
suite.NoError(err)
- suite.Equal(gtsmodel.NotificationFave, notif.NotificationType)
+ suite.Equal(gtsmodel.NotificationFavourite, notif.NotificationType)
suite.Equal(fave.TargetAccountID, notif.TargetAccountID)
suite.Equal(fave.AccountID, notif.OriginAccountID)
suite.Equal(fave.StatusID, notif.StatusID)
diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go
index 1520d2ec0..7773e80d3 100644
--- a/internal/processing/workers/surfacenotify.go
+++ b/internal/processing/workers/surfacenotify.go
@@ -250,7 +250,7 @@ func (s *Surface) notifyFave(
// notify status author
// of fave by account.
if err := s.Notify(ctx,
- gtsmodel.NotificationFave,
+ gtsmodel.NotificationFavourite,
fave.TargetAccount,
fave.Account,
fave.StatusID,
@@ -521,7 +521,7 @@ func (s *Surface) notifySignup(ctx context.Context, newUser *gtsmodel.User) erro
var errs gtserror.MultiError
for _, mod := range modAccounts {
if err := s.Notify(ctx,
- gtsmodel.NotificationSignup,
+ gtsmodel.NotificationAdminSignup,
mod,
newUser.Account,
"",
diff --git a/test/envparsing.sh b/test/envparsing.sh
index 927c5f98b..e5e69a710 100755
--- a/test/envparsing.sh
+++ b/test/envparsing.sh
@@ -75,6 +75,8 @@ EXPECT=$(cat << "EOF"
"user-mute-ids-mem-ratio": 3,
"user-mute-mem-ratio": 2,
"visibility-mem-ratio": 2,
+ "web-push-subscription-ids-mem-ratio": 1,
+ "web-push-subscription-mem-ratio": 1,
"webfinger-mem-ratio": 0.1
},
"config-path": "internal/config/testdata/test.yaml",
diff --git a/testrig/db.go b/testrig/db.go
index 5e423431c..c107b9b05 100644
--- a/testrig/db.go
+++ b/testrig/db.go
@@ -19,6 +19,7 @@ package testrig
import (
"context"
+
webpushgo "github.com/SherClockHolmes/webpush-go"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
@@ -60,6 +61,8 @@ var testModels = []interface{}{
>smodel.ThreadToStatus{},
>smodel.User{},
>smodel.UserMute{},
+ >smodel.VAPIDKeyPair{},
+ >smodel.WebPushSubscription{},
>smodel.Emoji{},
>smodel.Instance{},
>smodel.Notification{},
@@ -347,6 +350,12 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
}
}
+ for _, v := range NewTestWebPushSubscriptions() {
+ if err := db.Put(ctx, v); err != nil {
+ log.Panic(nil, err)
+ }
+ }
+
for _, v := range NewTestInteractionRequests() {
if err := db.Put(ctx, v); err != nil {
log.Panic(nil, err)
diff --git a/testrig/testmodels.go b/testrig/testmodels.go
index ae69b9e81..c9c0c7be5 100644
--- a/testrig/testmodels.go
+++ b/testrig/testmodels.go
@@ -2475,7 +2475,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification {
return map[string]*gtsmodel.Notification{
"local_account_1_like": {
ID: "01F8Q0ANPTWW10DAKTX7BRPBJP",
- NotificationType: gtsmodel.NotificationFave,
+ NotificationType: gtsmodel.NotificationFavourite,
CreatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"),
TargetAccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
OriginAccountID: "01F8MH17FWEB39HZJ76B6VXSKF",
@@ -2484,7 +2484,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification {
},
"local_account_2_like": {
ID: "01GTS6PRPXJYZBPFFQ56PP0XR8",
- NotificationType: gtsmodel.NotificationFave,
+ NotificationType: gtsmodel.NotificationFavourite,
CreatedAt: TimeMustParse("2022-01-13T12:45:01+02:00"),
TargetAccountID: "01F8MH17FWEB39HZJ76B6VXSKF",
OriginAccountID: "01F8MH5NBDF2MV7CTC4Q5128HF",
@@ -2493,7 +2493,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification {
},
"new_signup": {
ID: "01HTM9TETMB3YQCBKZ7KD4KV02",
- NotificationType: gtsmodel.NotificationSignup,
+ NotificationType: gtsmodel.NotificationAdminSignup,
CreatedAt: TimeMustParse("2022-06-04T13:12:00Z"),
TargetAccountID: "01F8MH17FWEB39HZJ76B6VXSKF",
OriginAccountID: "01F8MH0BBE4FHXPH513MBVFHB0",
@@ -3476,6 +3476,32 @@ func NewTestUserMutes() map[string]*gtsmodel.UserMute {
return map[string]*gtsmodel.UserMute{}
}
+func NewTestWebPushSubscriptions() map[string]*gtsmodel.WebPushSubscription {
+ return map[string]*gtsmodel.WebPushSubscription{
+ "local_account_1_token_1": {
+ ID: "01G65Z755AFWAKHE12NY0CQ9FH",
+ AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
+ TokenID: "01F8MGTQW4DKTDF8SW5CT9HYGA",
+ Endpoint: "https://example.test/push",
+ Auth: "cgna/fzrYLDQyPf5hD7IsA==",
+ P256dh: "BMYVItYVOX+AHBdtA62Q0i6c+F7MV2Gia3aoDr8mvHkuPBNIOuTLDfmFcnBqoZcQk6BtLcIONbxhHpy2R+mYIUY=",
+ NotifyFollow: util.Ptr(true),
+ NotifyFollowRequest: util.Ptr(true),
+ NotifyFavourite: util.Ptr(true),
+ NotifyMention: util.Ptr(true),
+ NotifyReblog: util.Ptr(true),
+ NotifyPoll: util.Ptr(true),
+ NotifyStatus: util.Ptr(true),
+ NotifyUpdate: util.Ptr(true),
+ NotifyAdminSignup: util.Ptr(true),
+ NotifyAdminReport: util.Ptr(true),
+ NotifyPendingFave: util.Ptr(true),
+ NotifyPendingReply: util.Ptr(true),
+ NotifyPendingReblog: util.Ptr(true),
+ },
+ }
+}
+
func NewTestInteractionRequests() map[string]*gtsmodel.InteractionRequest {
return map[string]*gtsmodel.InteractionRequest{
"admin_account_reply_turtle": {