Fix sql identifier escaping in datastore feed (#4746)

This commit is contained in:
Robert Kaussow 2025-01-19 22:41:59 +01:00 committed by GitHub
parent 0b657236d9
commit 08021ca593
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 48 additions and 7 deletions

View file

@ -15,12 +15,15 @@
package datastore package datastore
import ( import (
"fmt"
"xorm.io/builder" "xorm.io/builder"
"go.woodpecker-ci.org/woodpecker/v3/server/model" "go.woodpecker-ci.org/woodpecker/v3/server/model"
) )
var feedItemSelect = `repos.id as repo_id, func (s storage) getFeedSelect() string {
const feedTemplate = `repos.id as repo_id,
pipelines.id as pipeline_id, pipelines.id as pipeline_id,
pipelines.number as pipeline_number, pipelines.number as pipeline_number,
pipelines.event as pipeline_event, pipelines.event as pipeline_event,
@ -28,7 +31,7 @@ pipelines.status as pipeline_status,
pipelines.created as pipeline_created, pipelines.created as pipeline_created,
pipelines.started as pipeline_started, pipelines.started as pipeline_started,
pipelines.finished as pipeline_finished, pipelines.finished as pipeline_finished,
'pipelines.commit' as pipeline_commit, pipelines.%s as pipeline_commit,
pipelines.branch as pipeline_branch, pipelines.branch as pipeline_branch,
pipelines.ref as pipeline_ref, pipelines.ref as pipeline_ref,
pipelines.refspec as pipeline_refspec, pipelines.refspec as pipeline_refspec,
@ -38,10 +41,13 @@ pipelines.author as pipeline_author,
pipelines.email as pipeline_email, pipelines.email as pipeline_email,
pipelines.avatar as pipeline_avatar` pipelines.avatar as pipeline_avatar`
return fmt.Sprintf(feedTemplate, s.quoteIdentifier("commit"))
}
func (s storage) GetPipelineQueue() ([]*model.Feed, error) { func (s storage) GetPipelineQueue() ([]*model.Feed, error) {
feed := make([]*model.Feed, 0, perPage) feed := make([]*model.Feed, 0, perPage)
err := s.engine.Table("pipelines"). err := s.engine.Table("pipelines").
Select(feedItemSelect). Select(s.getFeedSelect()).
Join("INNER", "repos", "pipelines.repo_id = repos.id"). Join("INNER", "repos", "pipelines.repo_id = repos.id").
In("pipelines.status", model.StatusPending, model.StatusRunning). In("pipelines.status", model.StatusPending, model.StatusRunning).
Find(&feed) Find(&feed)
@ -51,7 +57,7 @@ func (s storage) GetPipelineQueue() ([]*model.Feed, error) {
func (s storage) UserFeed(user *model.User) ([]*model.Feed, error) { func (s storage) UserFeed(user *model.User) ([]*model.Feed, error) {
feed := make([]*model.Feed, 0, perPage) feed := make([]*model.Feed, 0, perPage)
err := s.engine.Table("repos"). err := s.engine.Table("repos").
Select(feedItemSelect). Select(s.getFeedSelect()).
Join("INNER", "perms", "repos.id = perms.repo_id"). Join("INNER", "perms", "repos.id = perms.repo_id").
Join("INNER", "pipelines", "repos.id = pipelines.repo_id"). Join("INNER", "pipelines", "repos.id = pipelines.repo_id").
Where(userPushOrAdminCondition(user.ID)). Where(userPushOrAdminCondition(user.ID)).
@ -66,7 +72,7 @@ func (s storage) RepoListLatest(user *model.User) ([]*model.Feed, error) {
feed := make([]*model.Feed, 0, perPage) feed := make([]*model.Feed, 0, perPage)
err := s.engine.Table("repos"). err := s.engine.Table("repos").
Select(feedItemSelect). Select(s.getFeedSelect()).
Join("INNER", "perms", "repos.id = perms.repo_id"). Join("INNER", "perms", "repos.id = perms.repo_id").
Join("LEFT", "pipelines", "pipelines.id = "+`( Join("LEFT", "pipelines", "pipelines.id = "+`(
SELECT pipelines.id FROM pipelines SELECT pipelines.id FROM pipelines

View file

@ -48,14 +48,37 @@ func TestGetPipelineQueue(t *testing.T) {
assert.NoError(t, store.PermUpsert(perm)) assert.NoError(t, store.PermUpsert(perm))
} }
pipeline1 := &model.Pipeline{ pipeline1 := &model.Pipeline{
RepoID: repo1.ID, RepoID: repo1.ID,
Status: model.StatusPending, Status: model.StatusPending,
Number: 1,
Event: "push",
Commit: "abc123",
Branch: "main",
Ref: "refs/heads/main",
Message: "Initial commit",
Author: "joe",
Email: "foo@bar.com",
Title: "First pipeline",
} }
assert.NoError(t, store.CreatePipeline(pipeline1)) assert.NoError(t, store.CreatePipeline(pipeline1))
feed, err := store.GetPipelineQueue() feed, err := store.GetPipelineQueue()
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, feed, 1) assert.Len(t, feed, 1)
feedItem := feed[0]
assert.Equal(t, repo1.ID, feedItem.RepoID)
assert.Equal(t, pipeline1.ID, feedItem.ID)
assert.Equal(t, pipeline1.Number, feedItem.Number)
assert.EqualValues(t, pipeline1.Event, feedItem.Event)
assert.EqualValues(t, pipeline1.Status, feedItem.Status)
assert.Equal(t, pipeline1.Commit, feedItem.Commit)
assert.Equal(t, pipeline1.Branch, feedItem.Branch)
assert.Equal(t, pipeline1.Ref, feedItem.Ref)
assert.Equal(t, pipeline1.Title, feedItem.Title)
assert.Equal(t, pipeline1.Message, feedItem.Message)
assert.Equal(t, pipeline1.Author, feedItem.Author)
assert.Equal(t, pipeline1.Email, feedItem.Email)
} }
func TestUserFeed(t *testing.T) { func TestUserFeed(t *testing.T) {

View file

@ -76,3 +76,15 @@ func callerName(skip int) string {
} }
return fnName return fnName
} }
func (s storage) quoteIdentifier(identifier string) string {
driver := s.engine.DriverName()
switch driver {
case DriverMysql:
return "`" + identifier + "`"
case DriverPostgres, DriverSqlite:
return "\"" + identifier + "\""
default:
return identifier
}
}