[bugfix] Better Postgres search case insensitivity (#2526)

* [bugfix] Better Postgres search case insensitivity

* use ilike for postgres
This commit is contained in:
tobi 2024-01-16 18:50:17 +01:00 committed by GitHub
parent 486585890d
commit c5eced5fd1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 29 deletions

View file

@ -133,8 +133,7 @@ func (s *searchDB) SearchForAccounts(
// Normalize it and just look for // Normalize it and just look for
// usernames that start with query. // usernames that start with query.
query = query[1:] query = query[1:]
subQ := s.accountUsername() q = whereStartsLike(q, bun.Ident("account.username"), query)
q = whereStartsLike(q, subQ, query)
} else { } else {
// Query looks like arbitrary string. // Query looks like arbitrary string.
// Search using LIKE for matches of query // Search using LIKE for matches of query
@ -199,14 +198,6 @@ func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery {
Where("? = ?", bun.Ident("follow.account_id"), accountID) Where("? = ?", bun.Ident("follow.account_id"), accountID)
} }
// accountUsername returns a subquery that just selects
// from account usernames, without concatenation.
func (s *searchDB) accountUsername() *bun.SelectQuery {
return s.db.
NewSelect().
Column("account.username")
}
// accountText returns a subquery that selects a concatenation // accountText returns a subquery that selects a concatenation
// of account username and display name as "account_text". If // of account username and display name as "account_text". If
// `following` is true, then account note will also be included // `following` is true, then account note will also be included
@ -242,11 +233,8 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery {
// different number of placeholders depending on // different number of placeholders depending on
// following/not following. COALESCE calls ensure // following/not following. COALESCE calls ensure
// that we're not trying to concatenate null values. // that we're not trying to concatenate null values.
//
// SQLite search is case insensitive. switch d := s.db.Dialect().Name(); {
// Postgres searches get lowercased.
d := s.db.Dialect().Name()
switch {
case d == dialect.SQLite && following: case d == dialect.SQLite && following:
query = "? || COALESCE(?, ?) || COALESCE(?, ?) AS ?" query = "? || COALESCE(?, ?) || COALESCE(?, ?) AS ?"
@ -255,13 +243,13 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery {
query = "? || COALESCE(?, ?) AS ?" query = "? || COALESCE(?, ?) AS ?"
case d == dialect.PG && following: case d == dialect.PG && following:
query = "LOWER(CONCAT(?, COALESCE(?, ?), COALESCE(?, ?))) AS ?" query = "CONCAT(?, COALESCE(?, ?), COALESCE(?, ?)) AS ?"
case d == dialect.PG && !following: case d == dialect.PG && !following:
query = "LOWER(CONCAT(?, COALESCE(?, ?))) AS ?" query = "CONCAT(?, COALESCE(?, ?)) AS ?"
default: default:
panic("db conn was neither pg not sqlite") log.Panicf(nil, "db conn %s was neither pg nor sqlite", d)
} }
return accountText.ColumnExpr(query, args...) return accountText.ColumnExpr(query, args...)
@ -385,10 +373,7 @@ func (s *searchDB) statusText() *bun.SelectQuery {
// SQLite and Postgres use different // SQLite and Postgres use different
// syntaxes for concatenation. // syntaxes for concatenation.
// switch d := s.db.Dialect().Name(); d {
// SQLite search is case insensitive.
// Postgres searches get lowercased.
switch s.db.Dialect().Name() {
case dialect.SQLite: case dialect.SQLite:
statusText = statusText.ColumnExpr( statusText = statusText.ColumnExpr(
@ -398,12 +383,12 @@ func (s *searchDB) statusText() *bun.SelectQuery {
case dialect.PG: case dialect.PG:
statusText = statusText.ColumnExpr( statusText = statusText.ColumnExpr(
"LOWER(CONCAT(?, COALESCE(?, ?))) AS ?", "CONCAT(?, COALESCE(?, ?)) AS ?",
bun.Ident("status.content"), bun.Ident("status.content_warning"), "", bun.Ident("status.content"), bun.Ident("status.content_warning"), "",
bun.Ident("status_text")) bun.Ident("status_text"))
default: default:
panic("db conn was neither pg not sqlite") log.Panicf(nil, "db conn %s was neither pg nor sqlite", d)
} }
return statusText return statusText

View file

@ -46,6 +46,15 @@ func (suite *SearchTestSuite) TestSearchAccounts1HappyWithPrefix() {
suite.Len(accounts, 1) suite.Len(accounts, 1)
} }
func (suite *SearchTestSuite) TestSearchAccounts1HappyWithPrefixUpper() {
testAccount := suite.testAccounts["local_account_1"]
// Query will just look for usernames that start with "1HAPPY".
accounts, err := suite.db.SearchForAccounts(context.Background(), testAccount.ID, "@1HAPPY", "", "", 10, false, 0)
suite.NoError(err)
suite.Len(accounts, 1)
}
func (suite *SearchTestSuite) TestSearchAccounts1HappyNoPrefix() { func (suite *SearchTestSuite) TestSearchAccounts1HappyNoPrefix() {
testAccount := suite.testAccounts["local_account_1"] testAccount := suite.testAccounts["local_account_1"]
@ -63,6 +72,14 @@ func (suite *SearchTestSuite) TestSearchAccountsTurtleFollowing() {
suite.Len(accounts, 1) suite.Len(accounts, 1)
} }
func (suite *SearchTestSuite) TestSearchAccountsTurtleFollowingUpper() {
testAccount := suite.testAccounts["local_account_1"]
accounts, err := suite.db.SearchForAccounts(context.Background(), testAccount.ID, "TURTLE", "", "", 10, true, 0)
suite.NoError(err)
suite.Len(accounts, 1)
}
func (suite *SearchTestSuite) TestSearchAccountsPostFollowing() { func (suite *SearchTestSuite) TestSearchAccountsPostFollowing() {
testAccount := suite.testAccounts["local_account_1"] testAccount := suite.testAccounts["local_account_1"]

View file

@ -23,8 +23,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
) )
// likeEscaper is a thread-safe string replacer which escapes // likeEscaper is a thread-safe string replacer which escapes
@ -37,10 +39,29 @@ var likeEscaper = strings.NewReplacer(
`_`, `\_`, // Exactly one char. `_`, `\_`, // Exactly one char.
) )
// likeOperator returns an appropriate LIKE or
// ILIKE operator for the given query's dialect.
func likeOperator(query *bun.SelectQuery) string {
const (
like = "LIKE"
ilike = "ILIKE"
)
d := query.Dialect().Name()
if d == dialect.SQLite {
return like
} else if d == dialect.PG {
return ilike
}
log.Panicf(nil, "db conn %s was neither pg nor sqlite", d)
return ""
}
// whereLike appends a WHERE clause to the // whereLike appends a WHERE clause to the
// given SelectQuery, which searches for // given SelectQuery, which searches for
// matches of `search` in the given subQuery // matches of `search` in the given subQuery
// using LIKE. // using LIKE (SQLite) or ILIKE (Postgres).
func whereLike( func whereLike(
query *bun.SelectQuery, query *bun.SelectQuery,
subject interface{}, subject interface{},
@ -54,11 +75,14 @@ func whereLike(
// zero or more chars around the query. // zero or more chars around the query.
search = `%` + search + `%` search = `%` + search + `%`
// Get appropriate operator.
like := likeOperator(query)
// Append resulting WHERE // Append resulting WHERE
// clause to the main query. // clause to the main query.
return query.Where( return query.Where(
"(?) LIKE ? ESCAPE ?", "(?) ? ? ESCAPE ?",
subject, search, `\`, subject, bun.Safe(like), search, `\`,
) )
} }
@ -78,11 +102,14 @@ func whereStartsLike(
// zero or more chars after the query. // zero or more chars after the query.
search += `%` search += `%`
// Get appropriate operator.
like := likeOperator(query)
// Append resulting WHERE // Append resulting WHERE
// clause to the main query. // clause to the main query.
return query.Where( return query.Where(
"(?) LIKE ? ESCAPE ?", "(?) ? ? ESCAPE ?",
subject, search, `\`, subject, bun.Safe(like), search, `\`,
) )
} }