mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-12-04 22:36:34 +00:00
a48cce82b9
* [chore] Bump tooling versions, bump go -> v1.23.0 * undo silly change * sign * bump go version in go.mod * allow overflow in imaging * goreleaser deprecation notices * [chore] Upgrade golangci-lint, ignore existing int overflow warnings There is a new lint for unchecked int casts. Integer overflows are bad, but the old code that triggers this lint seems to be perfectly fine. Instead of disabling the lint entirely for new code as well, grandfather in existing code. * fix golangci-lint documentation link * revert unrelated changes * revert another unrelated change * get rid of remaining nolint:gosec * swagger updates * apply review feedback * fix wrong formatting specifier thing * fix the linter for real --------- Co-authored-by: tobi <tobi.smethurst@protonmail.com>
281 lines
6.3 KiB
Go
281 lines
6.3 KiB
Go
// GoToSocial
|
|
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
package domain
|
|
|
|
import (
|
|
"fmt"
|
|
"slices"
|
|
"strings"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// Cache provides a means of caching domains in memory to reduce
|
|
// load on an underlying storage mechanism, e.g. a database.
|
|
//
|
|
// The in-memory domain list is kept up-to-date by means of a passed
|
|
// loader function during every call to .Matches(). In the case of
|
|
// a nil internal domain list, the loader function is called to hydrate
|
|
// the cache with the latest list of domains.
|
|
//
|
|
// The .Clear() function can be used to invalidate the cache,
|
|
// e.g. when an entry is added / deleted from the database.
|
|
type Cache struct {
|
|
// current domain cache radix trie.
|
|
rootptr atomic.Pointer[root]
|
|
}
|
|
|
|
// Matches checks whether domain matches an entry in the cache.
|
|
// If the cache is not currently loaded, then the provided load
|
|
// function is used to hydrate it.
|
|
func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, error) {
|
|
// Load the current
|
|
// root pointer value.
|
|
ptr := c.rootptr.Load()
|
|
|
|
if ptr == nil {
|
|
// Cache is not hydrated.
|
|
//
|
|
// Load domains from callback.
|
|
domains, err := load()
|
|
if err != nil {
|
|
return false, fmt.Errorf("error reloading cache: %w", err)
|
|
}
|
|
|
|
// Ensure the domains being inserted into the cache
|
|
// are sorted by number of domain parts. i.e. those
|
|
// with less parts are inserted last, else this can
|
|
// allow domains to fall through the matching code!
|
|
slices.SortFunc(domains, func(a, b string) int {
|
|
const k = +1
|
|
an := strings.Count(a, ".")
|
|
bn := strings.Count(b, ".")
|
|
switch {
|
|
case an < bn:
|
|
return +k
|
|
case an > bn:
|
|
return -k
|
|
default:
|
|
return 0
|
|
}
|
|
})
|
|
|
|
// Allocate new radix trie
|
|
// node to store matches.
|
|
ptr = new(root)
|
|
|
|
// Add each domain to the trie.
|
|
for _, domain := range domains {
|
|
ptr.Add(domain)
|
|
}
|
|
|
|
// Sort the trie.
|
|
ptr.Sort()
|
|
|
|
// Store new node ptr.
|
|
c.rootptr.Store(ptr)
|
|
}
|
|
|
|
// Look for match in trie node.
|
|
return ptr.Match(domain), nil
|
|
}
|
|
|
|
// Clear will drop the currently loaded domain list,
|
|
// triggering a reload on next call to .Matches().
|
|
func (c *Cache) Clear() { c.rootptr.Store(nil) }
|
|
|
|
// String returns a string representation of stored domains in cache.
|
|
func (c *Cache) String() string {
|
|
if ptr := c.rootptr.Load(); ptr != nil {
|
|
return ptr.String()
|
|
}
|
|
return "<empty>"
|
|
}
|
|
|
|
// root is the root node in the domain cache radix trie. this is the singular access point to the trie.
|
|
type root struct{ root node }
|
|
|
|
// Add will add the given domain to the radix trie.
|
|
func (r *root) Add(domain string) {
|
|
r.root.Add(strings.Split(domain, "."))
|
|
}
|
|
|
|
// Match will return whether the given domain matches
|
|
// an existing stored domain in this radix trie.
|
|
func (r *root) Match(domain string) bool {
|
|
return r.root.Match(strings.Split(domain, "."))
|
|
}
|
|
|
|
// Sort will sort the entire radix trie ensuring that
|
|
// child nodes are stored in alphabetical order. This
|
|
// MUST be done to finalize the domain cache in order
|
|
// to speed up the binary search of node child parts.
|
|
func (r *root) Sort() {
|
|
r.root.sort()
|
|
}
|
|
|
|
// String returns a string representation of node (and its descendants).
|
|
func (r *root) String() string {
|
|
buf := new(strings.Builder)
|
|
r.root.WriteStr(buf, "")
|
|
return buf.String()
|
|
}
|
|
|
|
type node struct {
|
|
part string
|
|
child []*node
|
|
}
|
|
|
|
func (n *node) Add(parts []string) {
|
|
if len(parts) == 0 {
|
|
panic("invalid domain")
|
|
}
|
|
|
|
for {
|
|
// Pop next domain part.
|
|
i := len(parts) - 1
|
|
part := parts[i]
|
|
parts = parts[:i]
|
|
|
|
var nn *node
|
|
|
|
// Look for existing child node
|
|
// that matches next domain part.
|
|
for _, child := range n.child {
|
|
if child.part == part {
|
|
nn = child
|
|
break
|
|
}
|
|
}
|
|
|
|
if nn == nil {
|
|
// Alloc new child node.
|
|
nn = &node{part: part}
|
|
n.child = append(n.child, nn)
|
|
}
|
|
|
|
if len(parts) == 0 {
|
|
// Drop all children here as
|
|
// this is a higher-level domain
|
|
// than that we previously had.
|
|
nn.child = nil
|
|
return
|
|
}
|
|
|
|
// Re-iter with
|
|
// child node.
|
|
n = nn
|
|
}
|
|
}
|
|
|
|
func (n *node) Match(parts []string) bool {
|
|
for len(parts) > 0 {
|
|
// Pop next domain part.
|
|
i := len(parts) - 1
|
|
part := parts[i]
|
|
parts = parts[:i]
|
|
|
|
// Look for existing child
|
|
// that matches next part.
|
|
nn := n.getChild(part)
|
|
|
|
if nn == nil {
|
|
// No match :(
|
|
return false
|
|
}
|
|
|
|
if len(nn.child) == 0 {
|
|
// It's a match!
|
|
return true
|
|
}
|
|
|
|
// Re-iter with
|
|
// child node.
|
|
n = nn
|
|
}
|
|
|
|
// Ran out of parts
|
|
// without a match.
|
|
return false
|
|
}
|
|
|
|
// getChild fetches child node with given domain part string
|
|
// using a binary search. THIS ASSUMES CHILDREN ARE SORTED.
|
|
func (n *node) getChild(part string) *node {
|
|
i, j := 0, len(n.child)
|
|
|
|
for i < j {
|
|
// avoid overflow when computing h
|
|
h := int(uint(i+j) >> 1) // #nosec G115
|
|
// i ≤ h < j
|
|
|
|
if n.child[h].part < part {
|
|
// preserves:
|
|
// n.child[i-1].part != part
|
|
i = h + 1
|
|
} else {
|
|
// preserves:
|
|
// n.child[h].part == part
|
|
j = h
|
|
}
|
|
}
|
|
|
|
if i >= len(n.child) || n.child[i].part != part {
|
|
return nil // no match
|
|
}
|
|
|
|
return n.child[i]
|
|
}
|
|
|
|
func (n *node) sort() {
|
|
// Sort this node's slice of child nodes.
|
|
slices.SortFunc(n.child, func(i, j *node) int {
|
|
const k = -1
|
|
switch {
|
|
case i.part < j.part:
|
|
return +k
|
|
case i.part > j.part:
|
|
return -k
|
|
default:
|
|
return 0
|
|
}
|
|
})
|
|
|
|
// Sort each child node's children.
|
|
for _, child := range n.child {
|
|
child.sort()
|
|
}
|
|
}
|
|
|
|
func (n *node) WriteStr(buf *strings.Builder, prefix string) {
|
|
if prefix != "" {
|
|
// Suffix joining '.'
|
|
prefix += "."
|
|
}
|
|
|
|
// Append current part.
|
|
prefix += n.part
|
|
|
|
// Dump current prefix state.
|
|
buf.WriteString(prefix)
|
|
buf.WriteByte('\n')
|
|
|
|
// Iterate through node children.
|
|
for _, child := range n.child {
|
|
child.WriteStr(buf, prefix)
|
|
}
|
|
}
|