Merge branch 'main' into relax_account_uri_uniqueness

This commit is contained in:
tobi 2025-04-06 11:00:08 +02:00
commit 39339dc832
71 changed files with 1735 additions and 910 deletions

View file

@ -113,6 +113,27 @@ nothanks.com,suspend,false,false,,false
JSON lists use content type `application/json`.
```json
[
{
"domain": "bumfaces.net",
"suspended_at": "2020-05-13T13:29:12.000Z",
"comment": "big jerks"
},
{
"domain": "peepee.poopoo",
"suspended_at": "2020-05-13T13:29:12.000Z",
"comment": "harassment"
},
{
"domain": "nothanks.com",
"suspended_at": "2020-05-13T13:29:12.000Z"
}
]
```
As an alternative to `"comment"`, `"public_comment"` will also work:
```json
[
{

View file

@ -1099,13 +1099,22 @@ definitions:
domain:
description: Domain represents a remote domain
properties:
comment:
description: |-
If the domain is blocked, what's the publicly-stated reason for the block.
Alternative to `public_comment` to be used when serializing/deserializing via /api/v1/instance.
example: they smell
type: string
x-go-name: Comment
domain:
description: The hostname of the domain.
example: example.org
type: string
x-go-name: Domain
public_comment:
description: If the domain is blocked, what's the publicly-stated reason for the block.
description: |-
If the domain is blocked, what's the publicly-stated reason for the block.
Alternative to `comment` to be used when serializing/deserializing NOT via /api/v1/instance.
example: they smell
type: string
x-go-name: PublicComment
@ -1124,6 +1133,13 @@ definitions:
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model
domainPermission:
properties:
comment:
description: |-
If the domain is blocked, what's the publicly-stated reason for the block.
Alternative to `public_comment` to be used when serializing/deserializing via /api/v1/instance.
example: they smell
type: string
x-go-name: Comment
created_at:
description: Time at which the permission entry was created (ISO 8601 Datetime).
example: "2021-07-30T09:20:25+00:00"
@ -1162,7 +1178,9 @@ definitions:
type: string
x-go-name: PrivateComment
public_comment:
description: If the domain is blocked, what's the publicly-stated reason for the block.
description: |-
If the domain is blocked, what's the publicly-stated reason for the block.
Alternative to `comment` to be used when serializing/deserializing NOT via /api/v1/instance.
example: they smell
type: string
x-go-name: PublicComment
@ -5823,6 +5841,53 @@ paths:
summary: View domain allow with the given ID.
tags:
- admin
put:
consumes:
- multipart/form-data
operationId: domainAllowUpdate
parameters:
- description: The id of the domain allow.
in: path
name: id
required: true
type: string
- description: Obfuscate the name of the domain when serving it publicly. Eg., `example.org` becomes something like `ex***e.org`.
in: formData
name: obfuscate
type: boolean
- description: Public comment about this domain allow. This will be displayed alongside the domain allow if you choose to share allows.
in: formData
name: public_comment
type: string
- description: Private comment about this domain allow. Will only be shown to other admins, so this is a useful way of internally keeping track of why a certain domain ended up allowed.
in: formData
name: private_comment
type: string
produces:
- application/json
responses:
"200":
description: The updated domain allow.
schema:
$ref: '#/definitions/domainPermission'
"400":
description: bad request
"401":
description: unauthorized
"403":
description: forbidden
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- admin:write:domain_allows
summary: Update a single domain allow.
tags:
- admin
/api/v1/admin/domain_blocks:
get:
operationId: domainBlocksGet
@ -5990,6 +6055,53 @@ paths:
summary: View domain block with the given ID.
tags:
- admin
put:
consumes:
- multipart/form-data
operationId: domainBlockUpdate
parameters:
- description: The id of the domain block.
in: path
name: id
required: true
type: string
- description: Obfuscate the name of the domain when serving it publicly. Eg., `example.org` becomes something like `ex***e.org`.
in: formData
name: obfuscate
type: boolean
- description: Public comment about this domain block. This will be displayed alongside the domain block if you choose to share blocks.
in: formData
name: public_comment
type: string
- description: Private comment about this domain block. Will only be shown to other admins, so this is a useful way of internally keeping track of why a certain domain ended up blocked.
in: formData
name: private_comment
type: string
produces:
- application/json
responses:
"200":
description: The updated domain block.
schema:
$ref: '#/definitions/domainPermission'
"400":
description: bad request
"401":
description: unauthorized
"403":
description: forbidden
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- admin:write:domain_blocks
summary: Update a single domain block.
tags:
- admin
/api/v1/admin/domain_keys_expire:
post:
consumes:

2
go.mod
View file

@ -54,7 +54,7 @@ require (
github.com/miekg/dns v1.1.64
github.com/minio/minio-go/v7 v7.0.85
github.com/mitchellh/mapstructure v1.5.0
github.com/ncruces/go-sqlite3 v0.24.0
github.com/ncruces/go-sqlite3 v0.25.0
github.com/oklog/ulid v1.3.1
github.com/prometheus/client_golang v1.21.1
github.com/rivo/uniseg v0.4.7

4
go.sum generated
View file

@ -322,8 +322,8 @@ github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-sqlite3 v0.24.0 h1:Z4jfmzu2NCd4SmyFwLT2OmF3EnTZbqwATvdiuNHNhLA=
github.com/ncruces/go-sqlite3 v0.24.0/go.mod h1:/Vs8ACZHjJ1SA6E9RZUn3EyB1OP3nDQ4z/ar+0fplTQ=
github.com/ncruces/go-sqlite3 v0.25.0 h1:trugKUs98Zwy9KwRr/EUxZHL92LYt7UqcKqAfpGpK+I=
github.com/ncruces/go-sqlite3 v0.25.0/go.mod h1:n6Z7036yFilJx04yV0mi5JWaF66rUmXn1It9Ux8dx68=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=

View file

@ -102,12 +102,14 @@ func (m *Module) Route(attachHandler func(method string, path string, f ...gin.H
attachHandler(http.MethodPost, DomainBlocksPath, m.DomainBlocksPOSTHandler)
attachHandler(http.MethodGet, DomainBlocksPath, m.DomainBlocksGETHandler)
attachHandler(http.MethodGet, DomainBlocksPathWithID, m.DomainBlockGETHandler)
attachHandler(http.MethodPut, DomainBlocksPathWithID, m.DomainBlockUpdatePUTHandler)
attachHandler(http.MethodDelete, DomainBlocksPathWithID, m.DomainBlockDELETEHandler)
// domain allow stuff
attachHandler(http.MethodPost, DomainAllowsPath, m.DomainAllowsPOSTHandler)
attachHandler(http.MethodGet, DomainAllowsPath, m.DomainAllowsGETHandler)
attachHandler(http.MethodGet, DomainAllowsPathWithID, m.DomainAllowGETHandler)
attachHandler(http.MethodPut, DomainAllowsPathWithID, m.DomainAllowUpdatePUTHandler)
attachHandler(http.MethodDelete, DomainAllowsPathWithID, m.DomainAllowDELETEHandler)
// domain permission draft stuff

View file

@ -0,0 +1,91 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// DomainAllowUpdatePUTHandler swagger:operation PUT /api/v1/admin/domain_allows/{id} domainAllowUpdate
//
// Update a single domain allow.
//
// ---
// tags:
// - admin
//
// consumes:
// - multipart/form-data
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: The id of the domain allow.
// in: path
// required: true
// -
// name: obfuscate
// in: formData
// description: >-
// Obfuscate the name of the domain when serving it publicly.
// Eg., `example.org` becomes something like `ex***e.org`.
// type: boolean
// -
// name: public_comment
// in: formData
// description: >-
// Public comment about this domain allow.
// This will be displayed alongside the domain allow if you choose to share allows.
// type: string
// -
// name: private_comment
// in: formData
// description: >-
// Private comment about this domain allow. Will only be shown to other admins, so this
// is a useful way of internally keeping track of why a certain domain ended up allowed.
// type: string
//
// security:
// - OAuth2 Bearer:
// - admin:write:domain_allows
//
// responses:
// '200':
// description: The updated domain allow.
// schema:
// "$ref": "#/definitions/domainPermission"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) DomainAllowUpdatePUTHandler(c *gin.Context) {
m.updateDomainPermission(c, gtsmodel.DomainPermissionAllow)
}

View file

@ -0,0 +1,91 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// DomainBlockUpdatePUTHandler swagger:operation PUT /api/v1/admin/domain_blocks/{id} domainBlockUpdate
//
// Update a single domain block.
//
// ---
// tags:
// - admin
//
// consumes:
// - multipart/form-data
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: The id of the domain block.
// in: path
// required: true
// -
// name: obfuscate
// in: formData
// description: >-
// Obfuscate the name of the domain when serving it publicly.
// Eg., `example.org` becomes something like `ex***e.org`.
// type: boolean
// -
// name: public_comment
// in: formData
// description: >-
// Public comment about this domain block.
// This will be displayed alongside the domain block if you choose to share blocks.
// type: string
// -
// name: private_comment
// in: formData
// description: >-
// Private comment about this domain block. Will only be shown to other admins, so this
// is a useful way of internally keeping track of why a certain domain ended up blocked.
// type: string
//
// security:
// - OAuth2 Bearer:
// - admin:write:domain_blocks
//
// responses:
// '200':
// description: The updated domain block.
// schema:
// "$ref": "#/definitions/domainPermission"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) DomainBlockUpdatePUTHandler(c *gin.Context) {
m.updateDomainPermission(c, gtsmodel.DomainPermissionBlock)
}

View file

@ -29,6 +29,7 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
type singleDomainPermCreate func(
@ -112,7 +113,7 @@ func (m *Module) createDomainPermissions(
if importing && form.Domains.Size == 0 {
err = errors.New("import was specified but list of domains is empty")
} else if !importing && form.Domain == "" {
err = errors.New("empty domain provided")
err = errors.New("no domain provided")
}
if err != nil {
@ -122,14 +123,14 @@ func (m *Module) createDomainPermissions(
if !importing {
// Single domain permission creation.
domainBlock, _, errWithCode := single(
perm, _, errWithCode := single(
c.Request.Context(),
permType,
authed.Account,
form.Domain,
form.Obfuscate,
form.PublicComment,
form.PrivateComment,
util.PtrOrZero(form.Obfuscate),
util.PtrOrZero(form.PublicComment),
util.PtrOrZero(form.PrivateComment),
"", // No sub ID for single perm creation.
)
@ -138,7 +139,7 @@ func (m *Module) createDomainPermissions(
return
}
apiutil.JSON(c, http.StatusOK, domainBlock)
apiutil.JSON(c, http.StatusOK, perm)
return
}
@ -177,6 +178,82 @@ func (m *Module) createDomainPermissions(
apiutil.JSON(c, http.StatusOK, domainPerms)
}
func (m *Module) updateDomainPermission(
c *gin.Context,
permType gtsmodel.DomainPermissionType,
) {
// Scope differs based on permType.
var requireScope apiutil.Scope
if permType == gtsmodel.DomainPermissionBlock {
requireScope = apiutil.ScopeAdminWriteDomainBlocks
} else {
requireScope = apiutil.ScopeAdminWriteDomainAllows
}
authed, errWithCode := apiutil.TokenAuth(c,
true, true, true, true,
requireScope,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if !*authed.User.Admin {
err := fmt.Errorf("user %s not an admin", authed.User.ID)
apiutil.ErrorHandler(c, gtserror.NewErrorForbidden(err, err.Error()), m.processor.InstanceGetV1)
return
}
if authed.Account.IsMoving() {
apiutil.ForbiddenAfterMove(c)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
permID, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey))
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
// Parse + validate form.
form := new(apimodel.DomainPermissionRequest)
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if form.Obfuscate == nil &&
form.PrivateComment == nil &&
form.PublicComment == nil {
const errText = "empty form submitted"
errWithCode := gtserror.NewErrorBadRequest(errors.New(errText), errText)
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
perm, errWithCode := m.processor.Admin().DomainPermissionUpdate(
c.Request.Context(),
permType,
permID,
form.Obfuscate,
form.PublicComment,
form.PrivateComment,
nil, // Can't update perm sub ID this way yet.
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
apiutil.JSON(c, http.StatusOK, perm)
}
// deleteDomainPermission deletes a single domain permission (block or allow).
func (m *Module) deleteDomainPermission(
c *gin.Context,

View file

@ -26,6 +26,7 @@ import (
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// DomainPermissionDraftsPOSTHandler swagger:operation POST /api/v1/admin/domain_permission_drafts domainPermissionDraftCreate
@ -148,9 +149,9 @@ func (m *Module) DomainPermissionDraftsPOSTHandler(c *gin.Context) {
authed.Account,
form.Domain,
permType,
form.Obfuscate,
form.PublicComment,
form.PrivateComment,
util.PtrOrZero(form.Obfuscate),
util.PtrOrZero(form.PublicComment),
util.PtrOrZero(form.PrivateComment),
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)

View file

@ -97,14 +97,21 @@ func (suite *DomainPermissionSubscriptionTestTestSuite) TestDomainPermissionSubs
suite.Equal(`[
{
"domain": "bumfaces.net",
"public_comment": "big jerks"
"public_comment": "big jerks",
"obfuscate": false,
"private_comment": ""
},
{
"domain": "peepee.poopoo",
"public_comment": "harassment"
"public_comment": "harassment",
"obfuscate": false,
"private_comment": ""
},
{
"domain": "nothanks.com"
"domain": "nothanks.com",
"public_comment": "",
"obfuscate": false,
"private_comment": ""
}
]`, dst.String())
@ -177,13 +184,22 @@ func (suite *DomainPermissionSubscriptionTestTestSuite) TestDomainPermissionSubs
// Ensure expected.
suite.Equal(`[
{
"domain": "bumfaces.net"
"domain": "bumfaces.net",
"public_comment": "",
"obfuscate": false,
"private_comment": ""
},
{
"domain": "peepee.poopoo"
"domain": "peepee.poopoo",
"public_comment": "",
"obfuscate": false,
"private_comment": ""
},
{
"domain": "nothanks.com"
"domain": "nothanks.com",
"public_comment": "",
"obfuscate": false,
"private_comment": ""
}
]`, dst.String())

View file

@ -136,7 +136,7 @@ func (suite *InstancePeersGetTestSuite) TestInstancePeersGetOnlySuspended() {
{
"domain": "replyguys.com",
"suspended_at": "2020-05-13T13:29:12.000Z",
"public_comment": "reply-guying to tech posts"
"comment": "reply-guying to tech posts"
}
]`, dst.String())
}
@ -186,7 +186,7 @@ func (suite *InstancePeersGetTestSuite) TestInstancePeersGetOnlySuspendedAuthori
{
"domain": "replyguys.com",
"suspended_at": "2020-05-13T13:29:12.000Z",
"public_comment": "reply-guying to tech posts"
"comment": "reply-guying to tech posts"
}
]`, dst.String())
}
@ -219,7 +219,7 @@ func (suite *InstancePeersGetTestSuite) TestInstancePeersGetAll() {
{
"domain": "replyguys.com",
"suspended_at": "2020-05-13T13:29:12.000Z",
"public_comment": "reply-guying to tech posts"
"comment": "reply-guying to tech posts"
}
]`, dst.String())
}
@ -263,12 +263,12 @@ func (suite *InstancePeersGetTestSuite) TestInstancePeersGetAllWithObfuscated()
{
"domain": "o*g.*u**.t**.*or*t.*r**ev**",
"suspended_at": "2021-06-09T10:34:55.000Z",
"public_comment": "just absolutely the worst, wowza"
"comment": "just absolutely the worst, wowza"
},
{
"domain": "replyguys.com",
"suspended_at": "2020-05-13T13:29:12.000Z",
"public_comment": "reply-guying to tech posts"
"comment": "reply-guying to tech posts"
}
]`, dst.String())
}

View file

@ -33,8 +33,13 @@ type Domain struct {
// example: 2021-07-30T09:20:25+00:00
SilencedAt string `json:"silenced_at,omitempty"`
// If the domain is blocked, what's the publicly-stated reason for the block.
// Alternative to `public_comment` to be used when serializing/deserializing via /api/v1/instance.
// example: they smell
PublicComment string `form:"public_comment" json:"public_comment,omitempty"`
Comment *string `form:"comment" json:"comment,omitempty"`
// If the domain is blocked, what's the publicly-stated reason for the block.
// Alternative to `comment` to be used when serializing/deserializing NOT via /api/v1/instance.
// example: they smell
PublicComment *string `form:"public_comment" json:"public_comment,omitempty"`
}
// DomainPermission represents a permission applied to one domain (explicit block/allow).
@ -48,10 +53,10 @@ type DomainPermission struct {
ID string `json:"id,omitempty"`
// Obfuscate the domain name when serving this domain permission entry publicly.
// example: false
Obfuscate bool `json:"obfuscate,omitempty"`
Obfuscate *bool `json:"obfuscate,omitempty"`
// Private comment for this permission entry, visible to this instance's admins only.
// example: they are poopoo
PrivateComment string `json:"private_comment,omitempty"`
PrivateComment *string `json:"private_comment,omitempty"`
// If applicable, the ID of the subscription that caused this domain permission entry to be created.
// example: 01FBW25TF5J67JW3HFHZCSD23K
SubscriptionID string `json:"subscription_id,omitempty"`
@ -80,14 +85,14 @@ type DomainPermissionRequest struct {
// Obfuscate the domain name when displaying this permission entry publicly.
// Ie., instead of 'example.org' show something like 'e**mpl*.or*'.
// example: false
Obfuscate bool `form:"obfuscate" json:"obfuscate"`
Obfuscate *bool `form:"obfuscate" json:"obfuscate"`
// Private comment for other admins on why this permission entry was created.
// example: don't like 'em!!!!
PrivateComment string `form:"private_comment" json:"private_comment"`
PrivateComment *string `form:"private_comment" json:"private_comment"`
// Public comment on why this permission entry was created.
// Will be visible to requesters at /api/v1/instance/peers if this endpoint is exposed.
// example: foss dorks 😫
PublicComment string `form:"public_comment" json:"public_comment"`
PublicComment *string `form:"public_comment" json:"public_comment"`
// Permission type to create (only applies to domain permission drafts, not explicit blocks and allows).
PermissionType string `form:"permission_type" json:"permission_type"`
}

View file

@ -36,7 +36,7 @@ type domainDB struct {
state *state.State
}
func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) (err error) {
func (d *domainDB) PutDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) (err error) {
// Normalize the domain as punycode, note the extra
// validation step for domain name write operations.
allow.Domain, err = util.PunifySafely(allow.Domain)
@ -162,7 +162,7 @@ func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
return nil
}
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error {
func (d *domainDB) PutDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error {
var err error
// Normalize the domain as punycode, note the extra

View file

@ -46,7 +46,7 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() {
suite.NoError(err)
suite.False(blocked)
err = suite.db.CreateDomainBlock(ctx, domainBlock)
err = suite.db.PutDomainBlock(ctx, domainBlock)
suite.NoError(err)
// domain block now exists
@ -75,7 +75,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedWithAllow() {
suite.False(blocked)
// Block this domain.
if err := suite.db.CreateDomainBlock(ctx, domainBlock); err != nil {
if err := suite.db.PutDomainBlock(ctx, domainBlock); err != nil {
suite.FailNow(err.Error())
}
@ -96,7 +96,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedWithAllow() {
CreatedByAccount: suite.testAccounts["admin_account"],
}
if err := suite.db.CreateDomainAllow(ctx, domainAllow); err != nil {
if err := suite.db.PutDomainAllow(ctx, domainAllow); err != nil {
suite.FailNow(err.Error())
}
@ -124,7 +124,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() {
suite.NoError(err)
suite.False(blocked)
err = suite.db.CreateDomainBlock(ctx, domainBlock)
err = suite.db.PutDomainBlock(ctx, domainBlock)
suite.NoError(err)
// Start with the base block domain
@ -164,7 +164,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() {
suite.NoError(err)
suite.False(blocked)
err = suite.db.CreateDomainBlock(ctx, domainBlock)
err = suite.db.PutDomainBlock(ctx, domainBlock)
suite.NoError(err)
// domain block now exists
@ -200,7 +200,7 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() {
suite.NoError(err)
suite.False(blocked)
err = suite.db.CreateDomainBlock(ctx, domainBlock)
err = suite.db.PutDomainBlock(ctx, domainBlock)
suite.NoError(err)
// domain block now exists
@ -232,7 +232,7 @@ func (suite *DomainTestSuite) TestIsOtherDomainBlockedWildcardAndExplicit() {
}
for _, block := range blocks {
if err := suite.db.CreateDomainBlock(ctx, block); err != nil {
if err := suite.db.PutDomainBlock(ctx, block); err != nil {
suite.FailNow(err.Error())
}
}

View file

@ -80,7 +80,7 @@ func (suite *DomainPermissionSubscriptionTestSuite) TestCount() {
// Whack the perms in the db.
for _, perm := range perms {
if err := suite.state.DB.CreateDomainBlock(ctx, perm); err != nil {
if err := suite.state.DB.PutDomainBlock(ctx, perm); err != nil {
suite.FailNow(err.Error())
}
}

View file

@ -31,8 +31,8 @@ type Domain interface {
Block/allow storage + retrieval functions.
*/
// CreateDomainAllow puts the given instance-level domain allow into the database.
CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error
// PutDomainAllow puts the given instance-level domain allow into the database.
PutDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error
// GetDomainAllow returns one instance-level domain allow with the given domain, if it exists.
GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error)
@ -49,8 +49,8 @@ type Domain interface {
// DeleteDomainAllow deletes an instance-level domain allow with the given domain, if it exists.
DeleteDomainAllow(ctx context.Context, domain string) error
// CreateDomainBlock puts the given instance-level domain block into the database.
CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error
// PutDomainBlock puts the given instance-level domain block into the database.
PutDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error
// GetDomainBlock returns one instance-level domain block with the given domain, if it exists.
GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error)

View file

@ -26,7 +26,7 @@ type DomainAllow struct {
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
Domain string `bun:",nullzero,notnull"` // domain to allow. Eg. 'whatever.com'
CreatedByAccountID string `bun:"type:CHAR(26),nullzero,notnull"` // Account ID of the creator of this allow
CreatedByAccount *Account `bun:"rel:belongs-to"` // Account corresponding to createdByAccountID
CreatedByAccount *Account `bun:"-"` // Account corresponding to createdByAccountID
PrivateComment string `bun:""` // Private comment on this allow, viewable to admins
PublicComment string `bun:""` // Public comment on this allow, viewable (optionally) by everyone
Obfuscate *bool `bun:",nullzero,notnull,default:false"` // whether the domain name should appear obfuscated when displaying it publicly

View file

@ -26,7 +26,7 @@ type DomainBlock struct {
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
Domain string `bun:",nullzero,notnull"` // domain to block. Eg. 'whatever.com'
CreatedByAccountID string `bun:"type:CHAR(26),nullzero,notnull"` // Account ID of the creator of this block
CreatedByAccount *Account `bun:"rel:belongs-to"` // Account corresponding to createdByAccountID
CreatedByAccount *Account `bun:"-"` // Account corresponding to createdByAccountID
PrivateComment string `bun:""` // Private comment on this block, viewable to admins
PublicComment string `bun:""` // Public comment on this block, viewable (optionally) by everyone
Obfuscate *bool `bun:",nullzero,notnull,default:false"` // whether the domain name should appear obfuscated when displaying it publicly

View file

@ -60,7 +60,7 @@ func (p *Processor) createDomainAllow(
}
// Insert the new allow into the database.
if err := p.state.DB.CreateDomainAllow(ctx, domainAllow); err != nil {
if err := p.state.DB.PutDomainAllow(ctx, domainAllow); err != nil {
err = gtserror.Newf("db error putting domain allow %s: %w", domain, err)
return nil, "", gtserror.NewErrorInternalError(err)
}
@ -92,6 +92,54 @@ func (p *Processor) createDomainAllow(
return apiDomainAllow, action.ID, nil
}
func (p *Processor) updateDomainAllow(
ctx context.Context,
domainAllowID string,
obfuscate *bool,
publicComment *string,
privateComment *string,
subscriptionID *string,
) (*apimodel.DomainPermission, gtserror.WithCode) {
domainAllow, err := p.state.DB.GetDomainAllowByID(ctx, domainAllowID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// Real error.
err = gtserror.Newf("db error getting domain allow: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
// There are just no entries for this ID.
err = fmt.Errorf("no domain allow entry exists with ID %s", domainAllowID)
return nil, gtserror.NewErrorNotFound(err, err.Error())
}
var columns []string
if obfuscate != nil {
domainAllow.Obfuscate = obfuscate
columns = append(columns, "obfuscate")
}
if publicComment != nil {
domainAllow.PublicComment = *publicComment
columns = append(columns, "public_comment")
}
if privateComment != nil {
domainAllow.PrivateComment = *privateComment
columns = append(columns, "private_comment")
}
if subscriptionID != nil {
domainAllow.SubscriptionID = *subscriptionID
columns = append(columns, "subscription_id")
}
// Update the domain allow.
if err := p.state.DB.UpdateDomainAllow(ctx, domainAllow, columns...); err != nil {
err = gtserror.Newf("db error updating domain allow: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiDomainPerm(ctx, domainAllow, false)
}
func (p *Processor) deleteDomainAllow(
ctx context.Context,
adminAcct *gtsmodel.Account,

View file

@ -60,7 +60,7 @@ func (p *Processor) createDomainBlock(
}
// Insert the new block into the database.
if err := p.state.DB.CreateDomainBlock(ctx, domainBlock); err != nil {
if err := p.state.DB.PutDomainBlock(ctx, domainBlock); err != nil {
err = gtserror.Newf("db error putting domain block %s: %w", domain, err)
return nil, "", gtserror.NewErrorInternalError(err)
}
@ -93,6 +93,54 @@ func (p *Processor) createDomainBlock(
return apiDomainBlock, action.ID, nil
}
func (p *Processor) updateDomainBlock(
ctx context.Context,
domainBlockID string,
obfuscate *bool,
publicComment *string,
privateComment *string,
subscriptionID *string,
) (*apimodel.DomainPermission, gtserror.WithCode) {
domainBlock, err := p.state.DB.GetDomainBlockByID(ctx, domainBlockID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// Real error.
err = gtserror.Newf("db error getting domain block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
// There are just no entries for this ID.
err = fmt.Errorf("no domain block entry exists with ID %s", domainBlockID)
return nil, gtserror.NewErrorNotFound(err, err.Error())
}
var columns []string
if obfuscate != nil {
domainBlock.Obfuscate = obfuscate
columns = append(columns, "obfuscate")
}
if publicComment != nil {
domainBlock.PublicComment = *publicComment
columns = append(columns, "public_comment")
}
if privateComment != nil {
domainBlock.PrivateComment = *privateComment
columns = append(columns, "private_comment")
}
if subscriptionID != nil {
domainBlock.SubscriptionID = *subscriptionID
columns = append(columns, "subscription_id")
}
// Update the domain block.
if err := p.state.DB.UpdateDomainBlock(ctx, domainBlock, columns...); err != nil {
err = gtserror.Newf("db error updating domain block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiDomainPerm(ctx, domainBlock, false)
}
func (p *Processor) deleteDomainBlock(
ctx context.Context,
adminAcct *gtsmodel.Account,

View file

@ -18,6 +18,7 @@
package admin
import (
"cmp"
"context"
"encoding/json"
"errors"
@ -29,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// DomainPermissionCreate creates an instance-level permission
@ -84,6 +86,50 @@ func (p *Processor) DomainPermissionCreate(
}
}
// DomainPermissionUpdate updates a domain permission
// of the given permissionType, with the given ID.
func (p *Processor) DomainPermissionUpdate(
ctx context.Context,
permissionType gtsmodel.DomainPermissionType,
permID string,
obfuscate *bool,
publicComment *string,
privateComment *string,
subscriptionID *string,
) (*apimodel.DomainPermission, gtserror.WithCode) {
switch permissionType {
// Explicitly block a domain.
case gtsmodel.DomainPermissionBlock:
return p.updateDomainBlock(
ctx,
permID,
obfuscate,
publicComment,
privateComment,
subscriptionID,
)
// Explicitly allow a domain.
case gtsmodel.DomainPermissionAllow:
return p.updateDomainAllow(
ctx,
permID,
obfuscate,
publicComment,
privateComment,
subscriptionID,
)
// 🎵 Why don't we all strap bombs to our chests,
// and ride our bikes to the next G7 picnic?
// Seems easier with every clock-tick. 🎵
default:
err := gtserror.Newf("unrecognized permission type %d", permissionType)
return nil, gtserror.NewErrorInternalError(err)
}
}
// DomainPermissionDelete removes one domain block with the given ID,
// and processes side effects of removing the block asynchronously.
//
@ -153,14 +199,14 @@ func (p *Processor) DomainPermissionsImport(
}
defer file.Close()
// Parse file as slice of domain blocks.
domainPerms := make([]*apimodel.DomainPermission, 0)
if err := json.NewDecoder(file).Decode(&domainPerms); err != nil {
// Parse file as slice of domain permissions.
apiDomainPerms := make([]*apimodel.DomainPermission, 0)
if err := json.NewDecoder(file).Decode(&apiDomainPerms); err != nil {
err = gtserror.Newf("error parsing attachment as domain permissions: %w", err)
return nil, gtserror.NewErrorBadRequest(err, err.Error())
}
count := len(domainPerms)
count := len(apiDomainPerms)
if count == 0 {
err = gtserror.New("error importing domain permissions: 0 entries provided")
return nil, gtserror.NewErrorBadRequest(err, err.Error())
@ -170,52 +216,97 @@ func (p *Processor) DomainPermissionsImport(
// between successes and errors so that the caller can
// try failed imports again if desired.
multiStatusEntries := make([]apimodel.MultiStatusEntry, 0, count)
for _, domainPerm := range domainPerms {
var (
domain = domainPerm.Domain.Domain
obfuscate = domainPerm.Obfuscate
publicComment = domainPerm.PublicComment
privateComment = domainPerm.PrivateComment
subscriptionID = "" // No sub ID for imports.
errWithCode gtserror.WithCode
for _, apiDomainPerm := range apiDomainPerms {
multiStatusEntries = append(
multiStatusEntries,
p.importOrUpdateDomainPerm(
ctx,
permissionType,
account,
apiDomainPerm,
),
)
domainPerm, _, errWithCode = p.DomainPermissionCreate(
ctx,
permissionType,
account,
domain,
obfuscate,
publicComment,
privateComment,
subscriptionID,
)
var entry *apimodel.MultiStatusEntry
if errWithCode != nil {
entry = &apimodel.MultiStatusEntry{
// Use the failed domain entry as the resource value.
Resource: domain,
Message: errWithCode.Safe(),
Status: errWithCode.Code(),
}
} else {
entry = &apimodel.MultiStatusEntry{
// Use successfully created API model domain block as the resource value.
Resource: domainPerm,
Message: http.StatusText(http.StatusOK),
Status: http.StatusOK,
}
}
multiStatusEntries = append(multiStatusEntries, *entry)
}
return apimodel.NewMultiStatus(multiStatusEntries), nil
}
func (p *Processor) importOrUpdateDomainPerm(
ctx context.Context,
permType gtsmodel.DomainPermissionType,
account *gtsmodel.Account,
apiDomainPerm *apimodel.DomainPermission,
) apimodel.MultiStatusEntry {
var (
domain = apiDomainPerm.Domain.Domain
obfuscate = apiDomainPerm.Obfuscate
publicComment = cmp.Or(apiDomainPerm.PublicComment, apiDomainPerm.Comment)
privateComment = apiDomainPerm.PrivateComment
subscriptionID = "" // No sub ID for imports.
)
// Check if this domain
// perm already exists.
var (
domainPerm gtsmodel.DomainPermission
err error
)
if permType == gtsmodel.DomainPermissionBlock {
domainPerm, err = p.state.DB.GetDomainBlock(ctx, domain)
} else {
domainPerm, err = p.state.DB.GetDomainAllow(ctx, domain)
}
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Real db error.
return apimodel.MultiStatusEntry{
Resource: domain,
Message: "db error checking for existence of domain permission",
Status: http.StatusInternalServerError,
}
}
var errWithCode gtserror.WithCode
if domainPerm != nil {
// Permission already exists, update it.
apiDomainPerm, errWithCode = p.DomainPermissionUpdate(
ctx,
permType,
domainPerm.GetID(),
obfuscate,
publicComment,
privateComment,
nil,
)
} else {
// Permission didn't exist yet, create it.
apiDomainPerm, _, errWithCode = p.DomainPermissionCreate(
ctx,
permType,
account,
domain,
util.PtrOrZero(obfuscate),
util.PtrOrZero(publicComment),
util.PtrOrZero(privateComment),
subscriptionID,
)
}
if errWithCode != nil {
return apimodel.MultiStatusEntry{
Resource: domain,
Message: errWithCode.Safe(),
Status: errWithCode.Code(),
}
}
return apimodel.MultiStatusEntry{
Resource: apiDomainPerm,
Message: http.StatusText(http.StatusOK),
Status: http.StatusOK,
}
}
// DomainPermissionsGet returns all existing domain
// permissions of the requested type. If export is
// true, the format will be suitable for writing out

View file

@ -106,9 +106,9 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
}
domains = append(domains, &apimodel.Domain{
Domain: d,
SuspendedAt: util.FormatISO8601(domainBlock.CreatedAt),
PublicComment: domainBlock.PublicComment,
Domain: d,
SuspendedAt: util.FormatISO8601(domainBlock.CreatedAt),
Comment: &domainBlock.PublicComment,
})
}
}

View file

@ -438,7 +438,7 @@ func (s *Subscriptions) processDomainPermission(
Obfuscate: wantedPerm.GetObfuscate(),
SubscriptionID: permSub.ID,
}
insertF = func() error { return s.state.DB.CreateDomainBlock(ctx, domainBlock) }
insertF = func() error { return s.state.DB.PutDomainBlock(ctx, domainBlock) }
action = &gtsmodel.AdminAction{
ID: id.NewULID(),
@ -461,7 +461,7 @@ func (s *Subscriptions) processDomainPermission(
Obfuscate: wantedPerm.GetObfuscate(),
SubscriptionID: permSub.ID,
}
insertF = func() error { return s.state.DB.CreateDomainAllow(ctx, domainAllow) }
insertF = func() error { return s.state.DB.PutDomainAllow(ctx, domainAllow) }
action = &gtsmodel.AdminAction{
ID: id.NewULID(),
@ -564,13 +564,13 @@ func permsFromCSV(
for i, columnHeader := range columnHeaders {
// Remove leading # if present.
normal := strings.TrimLeft(columnHeader, "#")
columnHeader = strings.TrimLeft(columnHeader, "#")
// Find index of each column header we
// care about, ensuring no duplicates.
switch normal {
switch {
case "domain":
case columnHeader == "domain":
if domainI != nil {
body.Close()
err := gtserror.NewfAt(3, "duplicate domain column header in csv: %+v", columnHeaders)
@ -578,7 +578,7 @@ func permsFromCSV(
}
domainI = &i
case "severity":
case columnHeader == "severity":
if severityI != nil {
body.Close()
err := gtserror.NewfAt(3, "duplicate severity column header in csv: %+v", columnHeaders)
@ -586,15 +586,15 @@ func permsFromCSV(
}
severityI = &i
case "public_comment":
case columnHeader == "public_comment" || columnHeader == "comment":
if publicCommentI != nil {
body.Close()
err := gtserror.NewfAt(3, "duplicate public_comment column header in csv: %+v", columnHeaders)
err := gtserror.NewfAt(3, "duplicate public_comment or comment column header in csv: %+v", columnHeaders)
return nil, err
}
publicCommentI = &i
case "obfuscate":
case columnHeader == "obfuscate":
if obfuscateI != nil {
body.Close()
err := gtserror.NewfAt(3, "duplicate obfuscate column header in csv: %+v", columnHeaders)
@ -674,15 +674,15 @@ func permsFromCSV(
perm.SetPublicComment(record[*publicCommentI])
}
var obfuscate bool
if obfuscateI != nil {
obfuscate, err := strconv.ParseBool(record[*obfuscateI])
obfuscate, err = strconv.ParseBool(record[*obfuscateI])
if err != nil {
l.Warnf("couldn't parse obfuscate field of record: %+v", record)
continue
}
perm.SetObfuscate(&obfuscate)
}
perm.SetObfuscate(&obfuscate)
// We're done.
perms = append(perms, perm)
@ -742,8 +742,9 @@ func permsFromJSON(
}
// Set remaining fields.
perm.SetPublicComment(apiPerm.PublicComment)
perm.SetObfuscate(&apiPerm.Obfuscate)
publicComment := cmp.Or(apiPerm.PublicComment, apiPerm.Comment)
perm.SetPublicComment(util.PtrOrZero(publicComment))
perm.SetObfuscate(util.Ptr(util.PtrOrZero(apiPerm.Obfuscate)))
// We're done.
perms = append(perms, perm)
@ -792,9 +793,15 @@ func permsFromPlain(
var perm gtsmodel.DomainPermission
switch permType {
case gtsmodel.DomainPermissionBlock:
perm = &gtsmodel.DomainBlock{Domain: domain}
perm = &gtsmodel.DomainBlock{
Domain: domain,
Obfuscate: util.Ptr(false),
}
case gtsmodel.DomainPermissionAllow:
perm = &gtsmodel.DomainAllow{Domain: domain}
perm = &gtsmodel.DomainAllow{
Domain: domain,
Obfuscate: util.Ptr(false),
}
}
// We're done.

View file

@ -775,7 +775,7 @@ func (suite *SubscriptionsTestSuite) TestAdoption() {
existingBlock2,
existingBlock3,
} {
if err := testStructs.State.DB.CreateDomainBlock(
if err := testStructs.State.DB.PutDomainBlock(
ctx, block,
); err != nil {
suite.FailNow(err.Error())
@ -876,7 +876,7 @@ func (suite *SubscriptionsTestSuite) TestDomainAllowsAndBlocks() {
}
// Store existing allow.
if err := testStructs.State.DB.CreateDomainAllow(ctx, existingAllow); err != nil {
if err := testStructs.State.DB.PutDomainAllow(ctx, existingAllow); err != nil {
suite.FailNow(err.Error())
}

View file

@ -2181,7 +2181,7 @@ func (c *Converter) DomainPermToAPIDomainPerm(
domainPerm := &apimodel.DomainPermission{
Domain: apimodel.Domain{
Domain: domain,
PublicComment: d.GetPublicComment(),
PublicComment: util.Ptr(d.GetPublicComment()),
},
}
@ -2192,8 +2192,8 @@ func (c *Converter) DomainPermToAPIDomainPerm(
}
domainPerm.ID = d.GetID()
domainPerm.Obfuscate = util.PtrOrZero(d.GetObfuscate())
domainPerm.PrivateComment = d.GetPrivateComment()
domainPerm.Obfuscate = d.GetObfuscate()
domainPerm.PrivateComment = util.Ptr(d.GetPrivateComment())
domainPerm.SubscriptionID = d.GetSubscriptionID()
domainPerm.CreatedBy = d.GetCreatedByAccountID()
if createdAt := d.GetCreatedAt(); !createdAt.IsZero() {

View file

@ -627,7 +627,7 @@ nothanks.com`
{
"domain": "bumfaces.net",
"suspended_at": "2020-05-13T13:29:12.000Z",
"public_comment": "big jerks"
"comment": "big jerks"
},
{
"domain": "peepee.poopoo",

View file

@ -65,17 +65,20 @@ db.QueryRow(`SELECT sqlite_version()`).Scan(&version)
This module replaces the SQLite [OS Interface](https://sqlite.org/vfs.html)
(aka VFS) with a [pure Go](vfs/) implementation,
which has advantages and disadvantages.
Read more about the Go VFS design [here](vfs/README.md).
Because each database connection executes within a Wasm sandboxed environment,
memory usage will be higher than alternatives.
### Testing
This project aims for [high test coverage](https://github.com/ncruces/go-sqlite3/wiki/Test-coverage-report).
It also benefits greatly from [SQLite's](https://sqlite.org/testing.html) and
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach) thorough testing.
[wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach)
thorough testing.
Every commit is [tested](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix) on
Linux (amd64/arm64/386/riscv64/ppc64le/s390x), macOS (amd64/arm64),
Linux (amd64/arm64/386/riscv64/ppc64le/s390x), macOS (arm64/amd64),
Windows (amd64), FreeBSD (amd64/arm64), OpenBSD (amd64), NetBSD (amd64/arm64),
DragonFly BSD (amd64), illumos (amd64), and Solaris (amd64).
@ -84,12 +87,21 @@ The Go VFS is tested by running SQLite's
### Performance
Perfomance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
Performance of the [`database/sql`](https://pkg.go.dev/database/sql) driver is
[competitive](https://github.com/cvilsmeier/go-sqlite-bench) with alternatives.
The Wasm and VFS layers are also tested by running SQLite's
The Wasm and VFS layers are also benchmarked by running SQLite's
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Concurrency
This module behaves similarly to SQLite in [multi-thread](https://sqlite.org/threadsafe.html) mode:
it is goroutine-safe, provided that no single database connection, or object derived from it,
is used concurrently by multiple goroutines.
The [`database/sql`](https://pkg.go.dev/database/sql) API is safe to use concurrently,
according to its documentation.
### FAQ, issues, new features
For questions, please see [Discussions](https://github.com/ncruces/go-sqlite3/discussions/categories/q-a).

View file

@ -31,6 +31,10 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
if c.interrupt.Err() != nil {
return nil, INTERRUPT
}
defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
@ -42,7 +46,6 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
flags = 1
}
c.checkInterrupt(c.handle)
rc := res_t(c.call("sqlite3_blob_open", stk_t(c.handle),
stk_t(dbPtr), stk_t(tablePtr), stk_t(columnPtr),
stk_t(row), stk_t(flags), stk_t(blobPtr)))
@ -253,7 +256,9 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
b.c.checkInterrupt(b.c.handle)
if b.c.interrupt.Err() != nil {
return INTERRUPT
}
err := b.c.error(res_t(b.c.call("sqlite3_blob_reopen", stk_t(b.handle), stk_t(row))))
b.bytes = int64(int32(b.c.call("sqlite3_blob_bytes", stk_t(b.handle))))
b.offset = 0

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strconv"
"sync/atomic"
"github.com/tetratelabs/wazero/api"
@ -48,6 +49,15 @@ func (c *Conn) Config(op DBConfig, arg ...bool) (bool, error) {
return util.ReadBool(c.mod, argsPtr), c.error(rc)
}
var defaultLogger atomic.Pointer[func(code ExtendedErrorCode, msg string)]
// ConfigLog sets up the default error logging callback for new connections.
//
// https://sqlite.org/errlog.html
func ConfigLog(cb func(code ExtendedErrorCode, msg string)) {
defaultLogger.Store(&cb)
}
// ConfigLog sets up the error logging callback for the connection.
//
// https://sqlite.org/errlog.html
@ -265,6 +275,10 @@ func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pAr
//
// https://sqlite.org/c3ref/wal_checkpoint_v2.html
func (c *Conn) WALCheckpoint(schema string, mode CheckpointMode) (nLog, nCkpt int, err error) {
if c.interrupt.Err() != nil {
return 0, 0, INTERRUPT
}
defer c.arena.mark()()
nLogPtr := c.arena.new(ptrlen)
nCkptPtr := c.arena.new(ptrlen)
@ -378,6 +392,6 @@ func (c *Conn) EnableChecksums(schema string) error {
}
// Checkpoint the WAL.
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_RESTART)
_, _, err = c.WALCheckpoint(schema, CHECKPOINT_FULL)
return err
}

View file

@ -25,7 +25,6 @@ type Conn struct {
*sqlite
interrupt context.Context
pending *Stmt
stmts []*Stmt
busy func(context.Context, int) bool
log func(xErrorCode, string)
@ -41,6 +40,7 @@ type Conn struct {
busylst time.Time
arena arena
handle ptr_t
gosched uint8
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
@ -49,7 +49,7 @@ func Open(filename string) (*Conn, error) {
}
// OpenContext is like [Open] but includes a context,
// which is used to interrupt the process of opening the connectiton.
// which is used to interrupt the process of opening the connection.
func OpenContext(ctx context.Context, filename string) (*Conn, error) {
return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
@ -92,6 +92,9 @@ func newConn(ctx context.Context, filename string, flags OpenFlag) (ret *Conn, _
}()
c.ctx = context.WithValue(c.ctx, connKey{}, c)
if logger := defaultLogger.Load(); logger != nil {
c.ConfigLog(*logger)
}
c.arena = c.newArena()
c.handle, err = c.openDB(filename, flags)
if err == nil {
@ -117,7 +120,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
return 0, err
}
c.call("sqlite3_progress_handler_go", stk_t(handle), 100)
c.call("sqlite3_progress_handler_go", stk_t(handle), 1000)
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok {
@ -129,7 +132,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (ptr_t, error) {
}
}
if pragmas.Len() != 0 {
c.checkInterrupt(handle)
pragmaPtr := c.arena.string(pragmas.String())
rc := res_t(c.call("sqlite3_exec", stk_t(handle), stk_t(pragmaPtr), 0, 0, 0))
if err := c.sqlite.error(rc, handle, pragmas.String()); err != nil {
@ -163,9 +165,6 @@ func (c *Conn) Close() error {
return nil
}
c.pending.Close()
c.pending = nil
rc := res_t(c.call("sqlite3_close", stk_t(c.handle)))
if err := c.error(rc); err != nil {
return err
@ -180,11 +179,16 @@ func (c *Conn) Close() error {
//
// https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
defer c.arena.mark()()
sqlPtr := c.arena.string(sql)
if c.interrupt.Err() != nil {
return INTERRUPT
}
return c.exec(sql)
}
c.checkInterrupt(c.handle)
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(sqlPtr), 0, 0, 0))
func (c *Conn) exec(sql string) error {
defer c.arena.mark()()
textPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_exec", stk_t(c.handle), stk_t(textPtr), 0, 0, 0))
return c.error(rc, sql)
}
@ -203,20 +207,22 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
if len(sql) > _MAX_SQL_LENGTH {
return nil, "", TOOBIG
}
if c.interrupt.Err() != nil {
return nil, "", INTERRUPT
}
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
textPtr := c.arena.string(sql)
c.checkInterrupt(c.handle)
rc := res_t(c.call("sqlite3_prepare_v3", stk_t(c.handle),
stk_t(sqlPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(textPtr), stk_t(len(sql)+1), stk_t(flags),
stk_t(stmtPtr), stk_t(tailPtr)))
stmt = &Stmt{c: c}
stmt = &Stmt{c: c, sql: sql}
stmt.handle = util.Read32[ptr_t](c.mod, stmtPtr)
if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-sqlPtr:]; sql != "" {
if sql := sql[util.Read32[ptr_t](c.mod, tailPtr)-textPtr:]; sql != "" {
tail = sql
}
@ -337,43 +343,17 @@ func (c *Conn) GetInterrupt() context.Context {
//
// https://sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
if ctx == nil {
panic("nil Context")
}
old = c.interrupt
c.interrupt = ctx
if ctx == old || ctx.Done() == old.Done() {
return old
}
// A busy SQL statement prevents SQLite from ignoring an interrupt
// that comes before any other statements are started.
if c.pending == nil {
defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen)
loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
c.call("sqlite3_prepare_v3", stk_t(c.handle), stk_t(loopPtr), math.MaxUint64,
stk_t(PREPARE_PERSISTENT), stk_t(stmtPtr), 0)
c.pending = &Stmt{c: c}
c.pending.handle = util.Read32[ptr_t](c.mod, stmtPtr)
}
if old.Done() != nil && ctx.Err() == nil {
c.pending.Reset()
}
if ctx.Done() != nil {
c.pending.Step()
}
return old
}
func (c *Conn) checkInterrupt(handle ptr_t) {
if c.interrupt.Err() != nil {
c.call("sqlite3_interrupt", stk_t(handle))
}
}
func progressCallback(ctx context.Context, mod api.Module, _ ptr_t) (interrupt int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok {
if c.interrupt.Done() != nil {
if c.gosched++; c.gosched%16 == 0 {
runtime.Gosched()
}
if c.interrupt.Err() != nil {
@ -429,11 +409,8 @@ func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool))
func busyCallback(ctx context.Context, mod api.Module, pDB ptr_t, count int32) (retry int32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
interrupt := c.interrupt
if interrupt == nil {
interrupt = context.Background()
}
if interrupt.Err() == nil && c.busy(interrupt, int(count)) {
if interrupt := c.interrupt; interrupt.Err() == nil &&
c.busy(interrupt, int(count)) {
retry = 1
}
}

View file

@ -11,10 +11,9 @@ const (
_ROW = 100 /* sqlite3_step() has another row ready */
_DONE = 101 /* sqlite3_step() has finished executing */
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
_MAX_FUNCTION_ARG = 100
_MAX_NAME = 1e6 // Self-imposed limit for most NUL terminated strings.
_MAX_LENGTH = 1e9
_MAX_SQL_LENGTH = 1e9
ptrlen = util.PtrLen
intlen = util.IntLen

View file

@ -89,20 +89,26 @@ func (ctx Context) ResultText(value string) {
}
// ResultRawText sets the text result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultRawText(value []byte) {
if len(value) == 0 {
ctx.ResultText("")
return
}
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_text_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))
}
// ResultBlob sets the result of the function to a []byte.
// Returning a nil slice is the same as calling [Context.ResultNull].
//
// https://sqlite.org/c3ref/result_blob.html
func (ctx Context) ResultBlob(value []byte) {
if len(value) == 0 {
ctx.ResultZeroBlob(0)
return
}
ptr := ctx.c.newBytes(value)
ctx.c.call("sqlite3_result_blob_go",
stk_t(ctx.handle), stk_t(ptr), stk_t(len(value)))

View file

@ -20,22 +20,45 @@
// - a [serializable] transaction is always "immediate";
// - a [read-only] transaction is always "deferred".
//
// # Datatypes In SQLite
//
// SQLite is dynamically typed.
// Columns can mostly hold any value regardless of their declared type.
// SQLite supports most [driver.Value] types out of the box,
// but bool and [time.Time] require special care.
//
// Booleans can be stored on any column type and scanned back to a *bool.
// However, if scanned to a *any, booleans may either become an
// int64, string or bool, depending on the declared type of the column.
// If you use BOOLEAN for your column type,
// 1 and 0 will always scan as true and false.
//
// # Working with time
//
// Time values can similarly be stored on any column type.
// The time encoding/decoding format can be specified using "_timefmt":
//
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
//
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
// Special values are: "auto" (the default), "sqlite", "rfc3339";
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
// - "rfc3339" encodes and decodes RFC 3339 only.
//
// If you encode as RFC 3339 (the default),
// consider using the TIME [collating sequence] to produce a time-ordered sequence.
// You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout].
//
// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful.
// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding.
// If you encode as RFC 3339 (the default),
// consider using the TIME [collating sequence] to produce time-ordered sequences.
//
// If you encode as RFC 3339 (the default),
// time values will scan back to a *time.Time unless your column type is TEXT.
// Otherwise, if scanned to a *any, time values may either become an
// int64, float64 or string, depending on the time format and declared type of the column.
// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type,
// "_timefmt" will be used to decode values.
//
// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful.
// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding.
//
// When using a custom time struct, you'll have to implement
// [database/sql/driver.Valuer] and [database/sql.Scanner].
@ -48,7 +71,7 @@
// The Scan method needs to take into account that the value it receives can be of differing types.
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
// Or it can be a: string, int64, float64, []byte, or nil,
// depending on the column type and what whoever wrote the value.
// depending on the column type and whoever wrote the value.
// [sqlite3.TimeFormat.Decode] may help.
//
// # Setting PRAGMAs
@ -358,13 +381,10 @@ func (c *conn) Commit() error {
}
func (c *conn) Rollback() error {
err := c.Conn.Exec(`ROLLBACK` + c.txReset)
if errors.Is(err, sqlite3.INTERRUPT) {
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
err = c.Conn.Exec(`ROLLBACK` + c.txReset)
}
return err
// ROLLBACK even if interrupted.
old := c.Conn.SetInterrupt(context.Background())
defer c.Conn.SetInterrupt(old)
return c.Conn.Exec(`ROLLBACK` + c.txReset)
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
@ -598,6 +618,28 @@ const (
_TIME
)
func scanFromDecl(decl string) scantype {
// These types are only used before we have rows,
// and otherwise as type hints.
// The first few ensure STRICT tables are strictly typed.
// The other two are type hints for booleans and time.
switch decl {
case "INT", "INTEGER":
return _INT
case "REAL":
return _REAL
case "TEXT":
return _TEXT
case "BLOB":
return _BLOB
case "BOOLEAN":
return _BOOL
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
return _TIME
}
return _ANY
}
var (
// Ensure these interfaces are implemented:
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
@ -622,6 +664,18 @@ func (r *rows) Columns() []string {
return r.names
}
func (r *rows) scanType(index int) scantype {
if r.scans == nil {
count := r.Stmt.ColumnCount()
scans := make([]scantype, count)
for i := range scans {
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
}
r.scans = scans
}
return r.scans[index]
}
func (r *rows) loadColumnMetadata() {
if r.nulls == nil {
count := r.Stmt.ColumnCount()
@ -635,24 +689,7 @@ func (r *rows) loadColumnMetadata() {
r.Stmt.ColumnTableName(i),
col)
types[i] = strings.ToUpper(types[i])
// These types are only used before we have rows,
// and otherwise as type hints.
// The first few ensure STRICT tables are strictly typed.
// The other two are type hints for booleans and time.
switch types[i] {
case "INT", "INTEGER":
scans[i] = _INT
case "REAL":
scans[i] = _REAL
case "TEXT":
scans[i] = _TEXT
case "BLOB":
scans[i] = _BLOB
case "BOOLEAN":
scans[i] = _BOOL
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
scans[i] = _TIME
}
scans[i] = scanFromDecl(types[i])
}
}
r.nulls = nulls
@ -661,27 +698,15 @@ func (r *rows) loadColumnMetadata() {
}
}
func (r *rows) declType(index int) string {
if r.types == nil {
count := r.Stmt.ColumnCount()
types := make([]string, count)
for i := range types {
types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
}
r.types = types
}
return r.types[index]
}
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
r.loadColumnMetadata()
decltype := r.types[index]
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
decltype = decltype[:i]
decl := r.types[index]
if len := len(decl); len > 0 && decl[len-1] == ')' {
if i := strings.LastIndexByte(decl, '('); i >= 0 {
decl = decl[:i]
}
}
return strings.TrimSpace(decltype)
return strings.TrimSpace(decl)
}
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
@ -748,36 +773,49 @@ func (r *rows) Next(dest []driver.Value) error {
}
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
err := r.Stmt.Columns(data...)
if err := r.Stmt.ColumnsRaw(data...); err != nil {
return err
}
for i := range dest {
if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t
}
}
return err
}
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
switch v := v.(type) {
case int64, float64:
// could be a time value
case string:
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
scan := r.scanType(i)
switch v := dest[i].(type) {
case int64:
if scan == _BOOL {
switch v {
case 1:
dest[i] = true
case 0:
dest[i] = false
}
continue
}
case []byte:
if len(v) == cap(v) { // a BLOB
continue
}
if scan != _TEXT {
switch r.tmWrite {
case "", time.RFC3339, time.RFC3339Nano:
t, ok := maybeTime(v)
if ok {
dest[i] = t
continue
}
}
}
dest[i] = string(v)
case float64:
break
default:
continue
}
t, ok := maybeTime(v)
if ok {
return t, true
if scan == _TIME {
t, err := r.tmRead.Decode(dest[i])
if err == nil {
dest[i] = t
continue
}
}
default:
return
}
switch r.declType(i) {
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
// could be a time value
default:
return
}
t, err := r.tmRead.Decode(v)
return t, err == nil
return nil
}

View file

@ -1,12 +1,15 @@
package driver
import "time"
import (
"bytes"
"time"
)
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) (_ time.Time, _ bool) {
func maybeTime(text []byte) (_ time.Time, _ bool) {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
@ -21,8 +24,8 @@ func maybeTime(text string) (_ time.Time, _ bool) {
// Slow path.
var buf [len(time.RFC3339Nano)]byte
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) {
date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) {
return date, true
}
return

View file

@ -2,7 +2,6 @@ package sqlite3
import (
"errors"
"strconv"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
@ -12,7 +11,6 @@ import (
//
// https://sqlite.org/c3ref/errcode.html
type Error struct {
str string
msg string
sql string
code res_t
@ -29,19 +27,13 @@ func (e *Error) Code() ErrorCode {
//
// https://sqlite.org/rescode.html
func (e *Error) ExtendedCode() ExtendedErrorCode {
return ExtendedErrorCode(e.code)
return xErrorCode(e.code)
}
// Error implements the error interface.
func (e *Error) Error() string {
var b strings.Builder
b.WriteString("sqlite3: ")
if e.str != "" {
b.WriteString(e.str)
} else {
b.WriteString(strconv.Itoa(int(e.code)))
}
b.WriteString(util.ErrorCodeString(uint32(e.code)))
if e.msg != "" {
b.WriteString(": ")
@ -103,12 +95,12 @@ func (e ErrorCode) Error() string {
// Temporary returns true for [BUSY] errors.
func (e ErrorCode) Temporary() bool {
return e == BUSY
return e == BUSY || e == INTERRUPT
}
// ExtendedCode returns the extended error code for this error.
func (e ErrorCode) ExtendedCode() ExtendedErrorCode {
return ExtendedErrorCode(e)
return xErrorCode(e)
}
// Error implements the error interface.
@ -133,7 +125,7 @@ func (e ExtendedErrorCode) As(err any) bool {
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY
return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT
}
// Timeout returns true for [BUSY_TIMEOUT] errors.

View file

@ -3,7 +3,9 @@ package sqlite3
import (
"context"
"io"
"iter"
"sync"
"sync/atomic"
"github.com/tetratelabs/wazero/api"
@ -45,7 +47,7 @@ func (c Conn) AnyCollationNeeded() error {
// CreateCollation defines a new collating sequence.
//
// https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
@ -57,6 +59,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
return c.error(rc)
}
// Collating function is the type of a collation callback.
// Implementations must not retain a or b.
type CollatingFunction func(a, b []byte) int
// CreateFunction defines a new scalar SQL function.
//
// https://sqlite.org/c3ref/create_function.html
@ -77,34 +83,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala
// Implementations must not retain arg.
type ScalarFunction func(ctx Context, arg ...Value)
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
// CreateAggregateFunction defines a new aggregate SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
call := "sqlite3_create_aggregate_function_go"
if fn != nil {
agg := fn()
if c, ok := agg.(io.Closer); ok {
if err := c.Close(); err != nil {
return err
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
var a aggregateFunc
coro := func(yieldCoro func(struct{}) bool) {
seq := func(yieldSeq func([]Value) bool) {
for yieldSeq(a.arg) {
if !yieldCoro(struct{}{}) {
break
}
}
}
fn(&a.ctx, seq)
}
}
if _, ok := agg.(WindowFunction); ok {
call = "sqlite3_create_window_function_go"
}
funcPtr = util.AddHandle(c.ctx, fn)
a.next, a.stop = iter.Pull(coro)
return &a
}))
}
rc := res_t(c.call(call,
rc := res_t(c.call("sqlite3_create_aggregate_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// AggregateSeqFunction is the type of an aggregate SQL function.
// Implementations must not retain the slices yielded by seq.
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
var funcPtr ptr_t
defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
agg := fn()
if win, ok := agg.(WindowFunction); ok {
return win
}
return windowFunc{agg, name}
}))
}
rc := res_t(c.call("sqlite3_create_window_function_go",
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
stk_t(flag), stk_t(funcPtr)))
return c.error(rc)
}
// AggregateConstructor is a an [AggregateFunction] constructor.
type AggregateConstructor func() AggregateFunction
// AggregateFunction is the interface an aggregate function should implement.
//
// https://sqlite.org/appfunc.html
@ -153,26 +192,24 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe
}
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
fn := util.GetHandle(ctx, pApp).(CollatingFunction)
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
}
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn := util.GetHandle(db.ctx, pApp).(ScalarFunction)
callbackArgs(db, args[:nArg], pArg)
fn(Context{db, pCtx}, args[:nArg]...)
fn(Context{db, pCtx}, *args...)
}
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
callbackArgs(db, args[:nArg], pArg)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn, _ := callbackAggregate(db, pAgg, pApp)
fn.Step(Context{db, pCtx}, args[:nArg]...)
fn.Step(Context{db, pCtx}, *args...)
}
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) {
@ -196,12 +233,11 @@ func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t,
}
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) {
args := getFuncArgs()
defer putFuncArgs(args)
db := ctx.Value(connKey{}).(*Conn)
callbackArgs(db, args[:nArg], pArg)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
fn.Inverse(Context{db, pCtx}, args[:nArg]...)
fn.Inverse(Context{db, pCtx}, *args...)
}
func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
@ -211,7 +247,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
}
// We need to create the aggregate.
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
if pAgg != 0 {
handle := util.AddHandle(db.ctx, fn)
util.Write32(db.mod, pAgg, handle)
@ -220,25 +256,64 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
return fn, 0
}
func callbackArgs(db *Conn, arg []Value, pArg ptr_t) {
for i := range arg {
arg[i] = Value{
var (
valueArgsPool sync.Pool
valueArgsLen atomic.Int32
)
func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value {
arg, ok := valueArgsPool.Get().(*[]Value)
if !ok || cap(*arg) < int(nArg) {
max := valueArgsLen.Or(nArg) | nArg
lst := make([]Value, max)
arg = &lst
}
lst := (*arg)[:nArg]
for i := range lst {
lst[i] = Value{
c: db,
handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen),
}
}
*arg = lst
return arg
}
var funcArgsPool sync.Pool
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
funcArgsPool.Put(p)
func returnArgs(p *[]Value) {
valueArgsPool.Put(p)
}
func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
if p := funcArgsPool.Get(); p == nil {
return new([_MAX_FUNCTION_ARG]Value)
} else {
return p.(*[_MAX_FUNCTION_ARG]Value)
type aggregateFunc struct {
next func() (struct{}, bool)
stop func()
ctx Context
arg []Value
}
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
a.ctx = ctx
a.arg = append(a.arg[:0], arg...)
if _, more := a.next(); !more {
a.stop()
}
}
func (a *aggregateFunc) Value(ctx Context) {
a.ctx = ctx
a.stop()
}
func (a *aggregateFunc) Close() error {
a.stop()
return nil
}
type windowFunc struct {
AggregateFunction
name string
}
func (w windowFunc) Inverse(ctx Context, arg ...Value) {
// Implementing inverse allows certain queries that don't really need it to succeed.
ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function"))
}

View file

@ -75,7 +75,7 @@ func ErrorCodeString(rc uint32) string {
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
case EMPTY:
break
case SCHEMA:
return "sqlite3: database schema has changed"
@ -91,7 +91,7 @@ func ErrorCodeString(rc uint32) string {
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
case FORMAT:
break
case RANGE:
return "sqlite3: column index out of range"

View file

@ -135,11 +135,10 @@ func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string {
panic(RangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(NoNulErr)
} else {
if i := bytes.IndexByte(buf, 0); i >= 0 {
return string(buf[:i])
}
panic(NoNulErr)
}
func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) {

View file

@ -120,33 +120,33 @@ func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM {
panic(util.OOMErr)
}
if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 {
err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME)
}
if handle != 0 {
var msg, query string
if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
switch {
case msg == "not an error":
msg = ""
case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
msg = ""
}
}
if len(sql) != 0 {
if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
err.sql = sql[0][i:]
query = sql[0][i:]
}
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
if msg != "" || query != "" {
return &Error{code: rc, msg: msg, sql: query}
}
}
return &err
return xErrorCode(rc)
}
func (sqlt *sqlite) getfn(name string) api.Function {
@ -212,14 +212,10 @@ func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t {
}
func (sqlt *sqlite) newBytes(b []byte) ptr_t {
if (*[0]byte)(b) == nil {
if len(b) == 0 {
return 0
}
size := len(b)
if size == 0 {
size = 1
}
ptr := sqlt.new(int64(size))
ptr := sqlt.new(int64(len(b)))
util.WriteBytes(sqlt.mod, ptr, b)
return ptr
}
@ -288,7 +284,7 @@ func (a *arena) new(size int64) ptr_t {
}
func (a *arena) bytes(b []byte) ptr_t {
if (*[0]byte)(b) == nil {
if len(b) == 0 {
return 0
}
ptr := a.new(int64(len(b)))

View file

@ -106,7 +106,14 @@ func (s *Stmt) Busy() bool {
//
// https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool {
s.c.checkInterrupt(s.c.handle)
if s.c.interrupt.Err() != nil {
s.err = INTERRUPT
return false
}
return s.step()
}
func (s *Stmt) step() bool {
rc := res_t(s.c.call("sqlite3_step", stk_t(s.handle)))
switch rc {
case _ROW:
@ -131,7 +138,11 @@ func (s *Stmt) Err() error {
// Exec is a convenience function that repeatedly calls [Stmt.Step] until it returns false,
// then calls [Stmt.Reset] to reset the statement and get any error that occurred.
func (s *Stmt) Exec() error {
for s.Step() {
if s.c.interrupt.Err() != nil {
return INTERRUPT
}
// TODO: implement this in C.
for s.step() {
}
return s.Reset()
}
@ -254,13 +265,15 @@ func (s *Stmt) BindText(param int, value string) error {
// BindRawText binds a []byte to the prepared statement as text.
// The leftmost SQL parameter has an index of 1.
// Binding a nil slice is the same as calling [Stmt.BindNull].
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindRawText(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
if len(value) == 0 {
return s.BindText(param, "")
}
ptr := s.c.newBytes(value)
rc := res_t(s.c.call("sqlite3_bind_text_go",
stk_t(s.handle), stk_t(param),
@ -270,13 +283,15 @@ func (s *Stmt) BindRawText(param int, value []byte) error {
// BindBlob binds a []byte to the prepared statement.
// The leftmost SQL parameter has an index of 1.
// Binding a nil slice is the same as calling [Stmt.BindNull].
//
// https://sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindBlob(param int, value []byte) error {
if len(value) > _MAX_LENGTH {
return TOOBIG
}
if len(value) == 0 {
return s.BindZeroBlob(param, 0)
}
ptr := s.c.newBytes(value)
rc := res_t(s.c.call("sqlite3_bind_blob_go",
stk_t(s.handle), stk_t(param),
@ -560,7 +575,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
func (s *Stmt) ColumnRawText(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_text",
stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr)
return s.columnRawBytes(col, ptr, 1)
}
// ColumnRawBlob returns the value of the result column as a []byte.
@ -572,10 +587,10 @@ func (s *Stmt) ColumnRawText(col int) []byte {
func (s *Stmt) ColumnRawBlob(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_blob",
stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr)
return s.columnRawBytes(col, ptr, 0)
}
func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
func (s *Stmt) columnRawBytes(col int, ptr ptr_t, nul int32) []byte {
if ptr == 0 {
rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle)))
if rc != _ROW && rc != _DONE {
@ -586,7 +601,7 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
n := int32(s.c.call("sqlite3_column_bytes",
stk_t(s.handle), stk_t(col)))
return util.View(s.c.mod, ptr, int64(n))
return util.View(s.c.mod, ptr, int64(n+nul))[:n]
}
// ColumnJSON parses the JSON-encoded value of the result column
@ -633,22 +648,12 @@ func (s *Stmt) ColumnValue(col int) Value {
// [INTEGER] columns will be retrieved as int64 values,
// [FLOAT] as float64, [NULL] as nil,
// [TEXT] as string, and [BLOB] as []byte.
// Any []byte are owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
func (s *Stmt) Columns(dest ...any) error {
defer s.c.arena.mark()()
count := int64(len(dest))
typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(count * 8)
rc := res_t(s.c.call("sqlite3_columns_go",
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
if err := s.c.error(rc); err != nil {
types, ptr, err := s.columns(int64(len(dest)))
if err != nil {
return err
}
types := util.View(s.c.mod, typePtr, count)
// Avoid bounds checks on types below.
if len(types) != len(dest) {
panic(util.AssertErr())
@ -657,26 +662,95 @@ func (s *Stmt) Columns(dest ...any) error {
for i := range dest {
switch types[i] {
case byte(INTEGER):
dest[i] = util.Read64[int64](s.c.mod, dataPtr)
dest[i] = util.Read64[int64](s.c.mod, ptr)
case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, dataPtr)
dest[i] = util.ReadFloat64(s.c.mod, ptr)
case byte(NULL):
dest[i] = nil
default:
ptr := util.Read32[ptr_t](s.c.mod, dataPtr+0)
if ptr == 0 {
dest[i] = []byte{}
continue
}
len := util.Read32[int32](s.c.mod, dataPtr+4)
buf := util.View(s.c.mod, ptr, int64(len))
if types[i] == byte(TEXT) {
case byte(TEXT):
len := util.Read32[int32](s.c.mod, ptr+4)
if len != 0 {
ptr := util.Read32[ptr_t](s.c.mod, ptr)
buf := util.View(s.c.mod, ptr, int64(len))
dest[i] = string(buf)
} else {
dest[i] = buf
dest[i] = ""
}
case byte(BLOB):
len := util.Read32[int32](s.c.mod, ptr+4)
if len != 0 {
ptr := util.Read32[ptr_t](s.c.mod, ptr)
buf := util.View(s.c.mod, ptr, int64(len))
tmp, _ := dest[i].([]byte)
dest[i] = append(tmp[:0], buf...)
} else {
dest[i], _ = dest[i].([]byte)
}
}
dataPtr += 8
ptr += 8
}
return nil
}
// ColumnsRaw populates result columns into the provided slice.
// The slice must have [Stmt.ColumnCount] length.
//
// [INTEGER] columns will be retrieved as int64 values,
// [FLOAT] as float64, [NULL] as nil,
// [TEXT] and [BLOB] as []byte.
// Any []byte are owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
func (s *Stmt) ColumnsRaw(dest ...any) error {
types, ptr, err := s.columns(int64(len(dest)))
if err != nil {
return err
}
// Avoid bounds checks on types below.
if len(types) != len(dest) {
panic(util.AssertErr())
}
for i := range dest {
switch types[i] {
case byte(INTEGER):
dest[i] = util.Read64[int64](s.c.mod, ptr)
case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, ptr)
case byte(NULL):
dest[i] = nil
default:
len := util.Read32[int32](s.c.mod, ptr+4)
if len == 0 && types[i] == byte(BLOB) {
dest[i] = []byte{}
} else {
cap := len
if types[i] == byte(TEXT) {
cap++
}
ptr := util.Read32[ptr_t](s.c.mod, ptr)
buf := util.View(s.c.mod, ptr, int64(cap))[:len]
dest[i] = buf
}
}
ptr += 8
}
return nil
}
func (s *Stmt) columns(count int64) ([]byte, ptr_t, error) {
defer s.c.arena.mark()()
typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(count * 8)
rc := res_t(s.c.call("sqlite3_columns_go",
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
if rc == res_t(MISUSE) {
return nil, 0, MISUSE
}
if err := s.c.error(rc); err != nil {
return nil, 0, err
}
return util.View(s.c.mod, typePtr, count), dataPtr, nil
}

View file

@ -2,7 +2,6 @@ package sqlite3
import (
"context"
"errors"
"math/rand"
"runtime"
"strconv"
@ -21,11 +20,13 @@ type Txn struct {
}
// Begin starts a deferred transaction.
// It panics if a transaction is in-progress.
// For nested transactions, use [Conn.Savepoint].
//
// https://sqlite.org/lang_transaction.html
func (c *Conn) Begin() Txn {
// BEGIN even if interrupted.
err := c.txnExecInterrupted(`BEGIN DEFERRED`)
err := c.exec(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
@ -120,7 +121,8 @@ func (tx Txn) Commit() error {
//
// https://sqlite.org/lang_transaction.html
func (tx Txn) Rollback() error {
return tx.c.txnExecInterrupted(`ROLLBACK`)
// ROLLBACK even if interrupted.
return tx.c.exec(`ROLLBACK`)
}
// Savepoint is a marker within a transaction
@ -143,7 +145,7 @@ func (c *Conn) Savepoint() Savepoint {
// Names can be reused, but this makes catching bugs more likely.
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))
err := c.txnExecInterrupted(`SAVEPOINT ` + name)
err := c.exec(`SAVEPOINT ` + name)
if err != nil {
panic(err)
}
@ -199,7 +201,7 @@ func (s Savepoint) Release(errp *error) {
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
if err != nil {
panic(err)
}
@ -212,17 +214,7 @@ func (s Savepoint) Release(errp *error) {
// https://sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name)
}
func (c *Conn) txnExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
return s.c.exec(`ROLLBACK TO ` + s.name)
}
// TxnState determines the transaction state of a database.

View file

@ -1,16 +0,0 @@
//go:build !windows
package osutil
import (
"io/fs"
"os"
)
// OpenFile behaves the same as [os.OpenFile],
// except on Windows it sets [syscall.FILE_SHARE_DELETE].
//
// See: https://go.dev/issue/32088#issuecomment-502850674
func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}

View file

@ -1,115 +0,0 @@
package osutil
import (
"io/fs"
"os"
. "syscall"
"unsafe"
)
// OpenFile behaves the same as [os.OpenFile],
// except on Windows it sets [syscall.FILE_SHARE_DELETE].
//
// See: https://go.dev/issue/32088#issuecomment-502850674
func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: ENOENT}
}
r, e := syscallOpen(name, flag|O_CLOEXEC, uint32(perm.Perm()))
if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e}
}
return os.NewFile(uintptr(r), name), nil
}
// syscallOpen is a copy of [syscall.Open]
// that uses [syscall.FILE_SHARE_DELETE].
//
// https://go.dev/src/syscall/syscall_windows.go
func syscallOpen(path string, mode int, perm uint32) (fd Handle, err error) {
if len(path) == 0 {
return InvalidHandle, ERROR_FILE_NOT_FOUND
}
pathp, err := UTF16PtrFromString(path)
if err != nil {
return InvalidHandle, err
}
var access uint32
switch mode & (O_RDONLY | O_WRONLY | O_RDWR) {
case O_RDONLY:
access = GENERIC_READ
case O_WRONLY:
access = GENERIC_WRITE
case O_RDWR:
access = GENERIC_READ | GENERIC_WRITE
}
if mode&O_CREAT != 0 {
access |= GENERIC_WRITE
}
if mode&O_APPEND != 0 {
access &^= GENERIC_WRITE
access |= FILE_APPEND_DATA
}
sharemode := uint32(FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE)
var sa *SecurityAttributes
if mode&O_CLOEXEC == 0 {
sa = makeInheritSa()
}
var createmode uint32
switch {
case mode&(O_CREAT|O_EXCL) == (O_CREAT | O_EXCL):
createmode = CREATE_NEW
case mode&(O_CREAT|O_TRUNC) == (O_CREAT | O_TRUNC):
createmode = CREATE_ALWAYS
case mode&O_CREAT == O_CREAT:
createmode = OPEN_ALWAYS
case mode&O_TRUNC == O_TRUNC:
createmode = TRUNCATE_EXISTING
default:
createmode = OPEN_EXISTING
}
var attrs uint32 = FILE_ATTRIBUTE_NORMAL
if perm&S_IWRITE == 0 {
attrs = FILE_ATTRIBUTE_READONLY
if createmode == CREATE_ALWAYS {
const _ERROR_BAD_NETPATH = Errno(53)
// We have been asked to create a read-only file.
// If the file already exists, the semantics of
// the Unix open system call is to preserve the
// existing permissions. If we pass CREATE_ALWAYS
// and FILE_ATTRIBUTE_READONLY to CreateFile,
// and the file already exists, CreateFile will
// change the file permissions.
// Avoid that to preserve the Unix semantics.
h, e := CreateFile(pathp, access, sharemode, sa, TRUNCATE_EXISTING, FILE_ATTRIBUTE_NORMAL, 0)
switch e {
case ERROR_FILE_NOT_FOUND, _ERROR_BAD_NETPATH, ERROR_PATH_NOT_FOUND:
// File does not exist. These are the same
// errors as Errno.Is checks for ErrNotExist.
// Carry on to create the file.
default:
// Success or some different error.
return h, e
}
}
}
if createmode == OPEN_EXISTING && access == GENERIC_READ {
// Necessary for opening directory handles.
attrs |= FILE_FLAG_BACKUP_SEMANTICS
}
if mode&O_SYNC != 0 {
const _FILE_FLAG_WRITE_THROUGH = 0x80000000
attrs |= _FILE_FLAG_WRITE_THROUGH
}
if mode&O_NONBLOCK != 0 {
attrs |= FILE_FLAG_OVERLAPPED
}
return CreateFile(pathp, access, sharemode, sa, createmode, attrs, 0)
}
func makeInheritSa() *SecurityAttributes {
var sa SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
return &sa
}

View file

@ -1,33 +0,0 @@
package osutil
import (
"io/fs"
"os"
)
// FS implements [fs.FS], [fs.StatFS], and [fs.ReadFileFS]
// using package [os].
//
// This filesystem does not respect [fs.ValidPath] rules,
// and fails [testing/fstest.TestFS]!
//
// Still, it can be a useful tool to unify implementations
// that can access either the [os] filesystem or an [fs.FS].
// It's OK to use this to open files, but you should avoid
// opening directories, resolving paths, or walking the file system.
type FS struct{}
// Open implements [fs.FS].
func (FS) Open(name string) (fs.File, error) {
return OpenFile(name, os.O_RDONLY, 0)
}
// ReadFileFS implements [fs.StatFS].
func (FS) Stat(name string) (fs.FileInfo, error) {
return os.Stat(name)
}
// ReadFile implements [fs.ReadFileFS].
func (FS) ReadFile(name string) ([]byte, error) {
return os.ReadFile(name)
}

View file

@ -1,2 +0,0 @@
// Package osutil implements operating system utilities.
package osutil

View file

@ -5,5 +5,5 @@ package sql3util
//
// https://sqlite.org/fileformat.html#pages
func ValidPageSize(s int) bool {
return 512 <= s && s <= 65536 && s&(s-1) == 0
return s&(s-1) == 0 && 512 <= s && s <= 65536
}

View file

@ -139,7 +139,7 @@ func (v Value) Blob(buf []byte) []byte {
// https://sqlite.org/c3ref/value_blob.html
func (v Value) RawText() []byte {
ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected()))
return v.rawBytes(ptr)
return v.rawBytes(ptr, 1)
}
// RawBlob returns the value as a []byte.
@ -149,16 +149,16 @@ func (v Value) RawText() []byte {
// https://sqlite.org/c3ref/value_blob.html
func (v Value) RawBlob() []byte {
ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected()))
return v.rawBytes(ptr)
return v.rawBytes(ptr, 0)
}
func (v Value) rawBytes(ptr ptr_t) []byte {
func (v Value) rawBytes(ptr ptr_t, nul int32) []byte {
if ptr == 0 {
return nil
}
n := int32(v.c.call("sqlite3_value_bytes", v.protected()))
return util.View(v.c.mod, ptr, int64(n))
return util.View(v.c.mod, ptr, int64(n+nul))[:n]
}
// Pointer gets the pointer associated with this value,

View file

@ -6,22 +6,30 @@ It replaces the default SQLite VFS with a **pure Go** implementation,
and exposes [interfaces](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs#VFS)
that should allow you to implement your own [custom VFSes](#custom-vfses).
Since it is a from scratch reimplementation,
there are naturally some ways it deviates from the original.
See the [support matrix](https://github.com/ncruces/go-sqlite3/wiki/Support-matrix)
for the list of supported OS and CPU architectures.
The main differences are [file locking](#file-locking) and [WAL mode](#write-ahead-logging) support.
Since this is a from scratch reimplementation,
there are naturally some ways it deviates from the original.
It's also not as battle tested as the original.
The main differences to be aware of are
[file locking](#file-locking) and
[WAL mode](#write-ahead-logging) support.
### File Locking
POSIX advisory locks, which SQLite uses on Unix, are
[broken by design](https://github.com/sqlite/sqlite/blob/b74eb0/src/os_unix.c#L1073-L1161).
POSIX advisory locks,
which SQLite uses on [Unix](https://github.com/sqlite/sqlite/blob/5d60f4/src/os_unix.c#L13-L14),
are [broken by design](https://github.com/sqlite/sqlite/blob/5d60f4/src/os_unix.c#L1074-L1162).
Instead, on Linux and macOS, this package uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
This package can also use
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2),
albeit with reduced concurrency (`BEGIN IMMEDIATE` behaves like `BEGIN EXCLUSIVE`).
albeit with reduced concurrency (`BEGIN IMMEDIATE` behaves like `BEGIN EXCLUSIVE`,
[docs](https://sqlite.org/lang_transaction.html#immediate)).
BSD locks are the default on BSD and illumos,
but you can opt into them with the `sqlite3_flock` build tag.
@ -44,11 +52,11 @@ to check if your build supports file locking.
### Write-Ahead Logging
On Unix, this package may use `mmap` to implement
On Unix, this package uses `mmap` to implement
[shared-memory for the WAL-index](https://sqlite.org/wal.html#implementation_of_shared_memory_for_the_wal_index),
like SQLite.
On Windows, this package may use `MapViewOfFile`, like SQLite.
On Windows, this package uses `MapViewOfFile`, like SQLite.
You can also opt into a cross-platform, in-process, memory sharing implementation
with the `sqlite3_dotlk` build tag.
@ -63,6 +71,11 @@ you must disable connection pooling by calling
You can use [`vfs.SupportsSharedMemory`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs#SupportsSharedMemory)
to check if your build supports shared memory.
### Blocking Locks
On Windows and macOS, this package implements
[Wal-mode blocking locks](https://sqlite.org/src/doc/tip/doc/wal-lock.md).
### Batch-Atomic Write
On Linux, this package may support
@ -94,8 +107,10 @@ The VFS can be customized with a few build tags:
> [`unix-flock` VFS](https://sqlite.org/compile.html#enable_locking_style);
> `sqlite3_dotlk` builds are compatible with the
> [`unix-dotfile` VFS](https://sqlite.org/compile.html#enable_locking_style).
> If incompatible file locking is used, accessing databases concurrently with
> _other_ SQLite libraries will eventually corrupt data.
> [!CAUTION]
> Concurrently accessing databases using incompatible VFSes
> will eventually corrupt data.
### Custom VFSes

View file

@ -49,9 +49,7 @@ func (c cksmFile) ReadAt(p []byte, off int64) (n int, err error) {
n, err = c.File.ReadAt(p, off)
p = p[:n]
// SQLite is reading the header of a database file.
if c.isDB && off == 0 && len(p) >= 100 &&
bytes.HasPrefix(p, []byte("SQLite format 3\000")) {
if isHeader(c.isDB, p, off) {
c.init((*[100]byte)(p))
}
@ -67,9 +65,7 @@ func (c cksmFile) ReadAt(p []byte, off int64) (n int, err error) {
}
func (c cksmFile) WriteAt(p []byte, off int64) (n int, err error) {
// SQLite is writing the first page of a database file.
if c.isDB && off == 0 && len(p) >= 100 &&
bytes.HasPrefix(p, []byte("SQLite format 3\000")) {
if isHeader(c.isDB, p, off) {
c.init((*[100]byte)(p))
}
@ -116,9 +112,11 @@ func (c cksmFile) fileControl(ctx context.Context, mod api.Module, op _FcntlOpco
c.inCkpt = true
case _FCNTL_CKPT_DONE:
c.inCkpt = false
}
if rc := vfsFileControlImpl(ctx, mod, c, op, pArg); rc != _NOTFOUND {
return rc
case _FCNTL_PRAGMA:
rc := vfsFileControlImpl(ctx, mod, c, op, pArg)
if rc != _NOTFOUND {
return rc
}
}
return vfsFileControlImpl(ctx, mod, c.File, op, pArg)
}
@ -135,6 +133,14 @@ func (f *cksmFlags) init(header *[100]byte) {
}
}
func isHeader(isDB bool, p []byte, off int64) bool {
check := sql3util.ValidPageSize(len(p))
if isDB {
check = off == 0 && len(p) >= 100
}
return check && bytes.HasPrefix(p, []byte("SQLite format 3\000"))
}
func cksmCompute(a []byte) (cksm [8]byte) {
var s1, s2 uint32
for len(a) >= 8 {

View file

@ -6,9 +6,8 @@ import (
"io/fs"
"os"
"path/filepath"
"runtime"
"syscall"
"github.com/ncruces/go-sqlite3/util/osutil"
)
type vfsOS struct{}
@ -40,7 +39,7 @@ func (vfsOS) Delete(path string, syncDir bool) error {
if err != nil {
return err
}
if canSyncDirs && syncDir {
if isUnix && syncDir {
f, err := os.Open(filepath.Dir(path))
if err != nil {
return _OK
@ -96,7 +95,7 @@ func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error
if name == nil {
f, err = os.CreateTemp(os.Getenv("SQLITE_TMPDIR"), "*.db")
} else {
f, err = osutil.OpenFile(name.String(), oflags, 0666)
f, err = os.OpenFile(name.String(), oflags, 0666)
}
if err != nil {
if name == nil {
@ -118,15 +117,17 @@ func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error
return nil, flags, _IOERR_FSTAT
}
}
if flags&OPEN_DELETEONCLOSE != 0 {
if isUnix && flags&OPEN_DELETEONCLOSE != 0 {
os.Remove(f.Name())
}
file := vfsFile{
File: f,
psow: true,
atomic: osBatchAtomic(f),
readOnly: flags&OPEN_READONLY != 0,
syncDir: canSyncDirs && isCreate && isJournl,
syncDir: isUnix && isCreate && isJournl,
delete: !isUnix && flags&OPEN_DELETEONCLOSE != 0,
shm: NewSharedMemory(name.String()+"-shm", flags),
}
return &file, flags, nil
@ -139,6 +140,8 @@ type vfsFile struct {
readOnly bool
keepWAL bool
syncDir bool
atomic bool
delete bool
psow bool
}
@ -152,6 +155,9 @@ var (
)
func (f *vfsFile) Close() error {
if f.delete {
defer os.Remove(f.Name())
}
if f.shm != nil {
f.shm.Close()
}
@ -175,7 +181,7 @@ func (f *vfsFile) Sync(flags SyncFlag) error {
if err != nil {
return err
}
if canSyncDirs && f.syncDir {
if isUnix && f.syncDir {
f.syncDir = false
d, err := os.Open(filepath.Dir(f.File.Name()))
if err != nil {
@ -200,12 +206,15 @@ func (f *vfsFile) SectorSize() int {
func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic {
ret := IOCAP_SUBPAGE_READ
if osBatchAtomic(f.File) {
if f.atomic {
ret |= IOCAP_BATCH_ATOMIC
}
if f.psow {
ret |= IOCAP_POWERSAFE_OVERWRITE
}
if runtime.GOOS == "windows" {
ret |= IOCAP_UNDELETABLE_WHEN_OPEN
}
return ret
}
@ -214,6 +223,9 @@ func (f *vfsFile) SizeHint(size int64) error {
}
func (f *vfsFile) HasMoved() (bool, error) {
if runtime.GOOS == "windows" {
return false, nil
}
fi, err := f.Stat()
if err != nil {
return false, err

View file

@ -50,11 +50,15 @@ func osDowngradeLock(file *os.File, _ LockLevel) _ErrorCode {
}
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
err := unix.Flock(int(file.Fd()), unix.LOCK_UN)
if err != nil {
return _IOERR_UNLOCK
for {
err := unix.Flock(int(file.Fd()), unix.LOCK_UN)
if err == nil {
return _OK
}
if err != unix.EINTR {
return _IOERR_UNLOCK
}
}
return _OK
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
@ -89,13 +93,18 @@ func osLock(file *os.File, typ int16, start, len int64, def _ErrorCode) _ErrorCo
}
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), unix.F_SETLK, &unix.Flock_t{
lock := unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
for {
err := unix.FcntlFlock(file.Fd(), unix.F_SETLK, &lock)
if err == nil {
return _OK
}
if err != unix.EINTR {
return _IOERR_UNLOCK
}
}
}

View file

@ -27,7 +27,12 @@ func osSync(file *os.File, fullsync, _ /*dataonly*/ bool) error {
if fullsync {
return file.Sync()
}
return unix.Fsync(int(file.Fd()))
for {
err := unix.Fsync(int(file.Fd()))
if err != unix.EINTR {
return err
}
}
}
func osAllocate(file *os.File, size int64) error {
@ -85,13 +90,18 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d
}
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &unix.Flock_t{
lock := unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
for {
err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &lock)
if err == nil {
return _OK
}
if err != unix.EINTR {
return _IOERR_UNLOCK
}
}
}

View file

@ -3,6 +3,7 @@
package vfs
import (
"io"
"os"
"time"
@ -11,14 +12,36 @@ import (
func osSync(file *os.File, _ /*fullsync*/, _ /*dataonly*/ bool) error {
// SQLite trusts Linux's fdatasync for all fsync's.
return unix.Fdatasync(int(file.Fd()))
for {
err := unix.Fdatasync(int(file.Fd()))
if err != unix.EINTR {
return err
}
}
}
func osAllocate(file *os.File, size int64) error {
if size == 0 {
return nil
}
return unix.Fallocate(int(file.Fd()), 0, 0, size)
for {
err := unix.Fallocate(int(file.Fd()), 0, 0, size)
if err == unix.EOPNOTSUPP {
break
}
if err != unix.EINTR {
return err
}
}
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
return file.Truncate(size)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
@ -37,22 +60,27 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d
}
var err error
switch {
case timeout < 0:
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLKW, &lock)
default:
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
case timeout < 0:
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLKW, &lock)
}
return osLockErrorCode(err, def)
}
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &unix.Flock_t{
lock := unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
for {
err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if err == nil {
return _OK
}
if err != unix.EINTR {
return _IOERR_UNLOCK
}
}
}

View file

@ -8,8 +8,8 @@ import (
)
const (
isUnix = false
_O_NOFOLLOW = 0
canSyncDirs = false
)
func osAccess(path string, flags AccessFlag) error {

View file

@ -10,8 +10,8 @@ import (
)
const (
isUnix = true
_O_NOFOLLOW = unix.O_NOFOLLOW
canSyncDirs = true
)
func osAccess(path string, flags AccessFlag) error {
@ -65,10 +65,15 @@ func osTestLock(file *os.File, start, len int64) (int16, _ErrorCode) {
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) != nil {
return 0, _IOERR_CHECKRESERVEDLOCK
for {
err := unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock)
if err == nil {
return lock.Type, _OK
}
if err != unix.EINTR {
return 0, _IOERR_CHECKRESERVEDLOCK
}
}
return lock.Type, _OK
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {

View file

@ -135,12 +135,10 @@ func osWriteLock(file *os.File, start, len uint32, timeout time.Duration) _Error
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
switch {
case timeout == 0:
default:
err = osLockEx(file, flags|windows.LOCKFILE_FAIL_IMMEDIATELY, start, len)
case timeout < 0:
err = osLockEx(file, flags, start, len)
default:
err = osLockExTimeout(file, flags, start, len, timeout)
}
return osLockErrorCode(err, def)
}
@ -162,37 +160,6 @@ func osLockEx(file *os.File, flags, start, len uint32) error {
0, len, 0, &windows.Overlapped{Offset: start})
}
func osLockExTimeout(file *os.File, flags, start, len uint32, timeout time.Duration) error {
event, err := windows.CreateEvent(nil, 1, 0, nil)
if err != nil {
return err
}
defer windows.CloseHandle(event)
fd := windows.Handle(file.Fd())
overlapped := &windows.Overlapped{
Offset: start,
HEvent: event,
}
err = windows.LockFileEx(fd, flags, 0, len, 0, overlapped)
if err != windows.ERROR_IO_PENDING {
return err
}
ms := (timeout + time.Millisecond - 1) / time.Millisecond
rc, err := windows.WaitForSingleObject(event, uint32(ms))
if rc == windows.WAIT_OBJECT_0 {
return nil
}
defer windows.CancelIoEx(fd, overlapped)
if err != nil {
return err
}
return windows.Errno(rc)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK

View file

@ -68,16 +68,11 @@ func (s *vfsShm) Close() error {
panic(util.AssertErr())
}
func (s *vfsShm) shmOpen() _ErrorCode {
func (s *vfsShm) shmOpen() (rc _ErrorCode) {
if s.vfsShmParent != nil {
return _OK
}
var f *os.File
// Close file on error.
// Keep this here to avoid confusing checklocks.
defer func() { f.Close() }()
vfsShmListMtx.Lock()
defer vfsShmListMtx.Unlock()
@ -98,11 +93,16 @@ func (s *vfsShm) shmOpen() _ErrorCode {
}
// Always open file read-write, as it will be shared.
f, err = os.OpenFile(s.path,
f, err := os.OpenFile(s.path,
os.O_RDWR|os.O_CREATE|_O_NOFOLLOW, 0666)
if err != nil {
return _CANTOPEN
}
defer func() {
if rc != _OK {
f.Close()
}
}()
// Dead man's switch.
if lock, rc := osTestLock(f, _SHM_DMS, 1); rc != _OK {
@ -131,7 +131,6 @@ func (s *vfsShm) shmOpen() _ErrorCode {
File: f,
info: fi,
}
f = nil // Don't close the file.
for i, g := range vfsShmList {
if g == nil {
vfsShmList[i] = s.vfsShmParent

View file

@ -7,14 +7,11 @@ import (
"io"
"os"
"sync"
"syscall"
"time"
"github.com/tetratelabs/wazero/api"
"golang.org/x/sys/windows"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/osutil"
)
type vfsShm struct {
@ -33,8 +30,6 @@ type vfsShm struct {
sync.Mutex
}
var _ blockingSharedMemory = &vfsShm{}
func (s *vfsShm) Close() error {
// Unmap regions.
for _, r := range s.regions {
@ -48,8 +43,7 @@ func (s *vfsShm) Close() error {
func (s *vfsShm) shmOpen() _ErrorCode {
if s.File == nil {
f, err := osutil.OpenFile(s.path,
os.O_RDWR|os.O_CREATE|syscall.O_NONBLOCK, 0666)
f, err := os.OpenFile(s.path, os.O_RDWR|os.O_CREATE, 0666)
if err != nil {
return _CANTOPEN
}
@ -67,7 +61,7 @@ func (s *vfsShm) shmOpen() _ErrorCode {
return _IOERR_SHMOPEN
}
}
rc := osReadLock(s.File, _SHM_DMS, 1, time.Millisecond)
rc := osReadLock(s.File, _SHM_DMS, 1, 0)
s.fileLock = rc == _OK
return rc
}
@ -135,11 +129,6 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext
}
func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) (rc _ErrorCode) {
var timeout time.Duration
if s.blocking {
timeout = time.Millisecond
}
switch {
case flags&_SHM_LOCK != 0:
defer s.shmAcquire(&rc)
@ -151,9 +140,9 @@ func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) (rc _ErrorCode) {
case flags&_SHM_UNLOCK != 0:
return osUnlock(s.File, _SHM_BASE+uint32(offset), uint32(n))
case flags&_SHM_SHARED != 0:
return osReadLock(s.File, _SHM_BASE+uint32(offset), uint32(n), timeout)
return osReadLock(s.File, _SHM_BASE+uint32(offset), uint32(n), 0)
case flags&_SHM_EXCLUSIVE != 0:
return osWriteLock(s.File, _SHM_BASE+uint32(offset), uint32(n), timeout)
return osWriteLock(s.File, _SHM_BASE+uint32(offset), uint32(n), 0)
default:
panic(util.AssertErr())
}
@ -184,7 +173,3 @@ func (s *vfsShm) shmUnmap(delete bool) {
os.Remove(s.path)
}
}
func (s *vfsShm) shmEnableBlocking(block bool) {
s.blocking = block
}

View file

@ -79,9 +79,12 @@ func implements[T any](typ reflect.Type) bool {
//
// https://sqlite.org/c3ref/declare_vtab.html
func (c *Conn) DeclareVTab(sql string) error {
if c.interrupt.Err() != nil {
return INTERRUPT
}
defer c.arena.mark()()
sqlPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(sqlPtr)))
textPtr := c.arena.string(sql)
rc := res_t(c.call("sqlite3_declare_vtab", stk_t(c.handle), stk_t(textPtr)))
return c.error(rc)
}
@ -162,6 +165,7 @@ type VTabDestroyer interface {
}
// A VTabUpdater allows a virtual table to be updated.
// Implementations must not retain arg.
type VTabUpdater interface {
VTab
// https://sqlite.org/vtab.html#xupdate
@ -241,6 +245,7 @@ type VTabSavepointer interface {
// to loop through the virtual table.
// A VTabCursor may optionally implement
// [io.Closer] to free resources.
// Implementations of Filter must not retain arg.
//
// https://sqlite.org/c3ref/vtab_cursor.html
type VTabCursor interface {
@ -489,12 +494,12 @@ func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo
}
func vtabUpdateCallback(ctx context.Context, mod api.Module, pVTab ptr_t, nArg int32, pArg, pRowID ptr_t) res_t {
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater)
db := ctx.Value(connKey{}).(*Conn)
args := make([]Value, nArg)
callbackArgs(db, args, pArg)
rowID, err := vtab.Update(args...)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
vtab := vtabGetHandle(ctx, mod, pVTab).(VTabUpdater)
rowID, err := vtab.Update(*args...)
if err == nil {
util.Write64(mod, pRowID, rowID)
}
@ -593,15 +598,17 @@ func cursorCloseCallback(ctx context.Context, mod api.Module, pCur ptr_t) res_t
}
func cursorFilterCallback(ctx context.Context, mod api.Module, pCur ptr_t, idxNum int32, idxStr ptr_t, nArg int32, pArg ptr_t) res_t {
cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor)
db := ctx.Value(connKey{}).(*Conn)
args := make([]Value, nArg)
callbackArgs(db, args, pArg)
args := callbackArgs(db, nArg, pArg)
defer returnArgs(args)
var idxName string
if idxStr != 0 {
idxName = util.ReadString(mod, idxStr, _MAX_LENGTH)
}
err := cursor.Filter(int(idxNum), idxName, args...)
cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor)
err := cursor.Filter(int(idxNum), idxName, *args...)
return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err)
}

3
vendor/modules.txt vendored
View file

@ -691,7 +691,7 @@ github.com/modern-go/reflect2
# github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
## explicit
github.com/munnerz/goautoneg
# github.com/ncruces/go-sqlite3 v0.24.0
# github.com/ncruces/go-sqlite3 v0.25.0
## explicit; go 1.23.0
github.com/ncruces/go-sqlite3
github.com/ncruces/go-sqlite3/driver
@ -699,7 +699,6 @@ github.com/ncruces/go-sqlite3/embed
github.com/ncruces/go-sqlite3/internal/alloc
github.com/ncruces/go-sqlite3/internal/dotlk
github.com/ncruces/go-sqlite3/internal/util
github.com/ncruces/go-sqlite3/util/osutil
github.com/ncruces/go-sqlite3/util/sql3util
github.com/ncruces/go-sqlite3/vfs
github.com/ncruces/go-sqlite3/vfs/memdb

View file

@ -93,29 +93,29 @@
font-family: "Atkinson Hyperlegible";
font-weight: normal;
font-style: normal;
src: url(/assets/fonts/Atkinson-Hyperlegible-Regular-102a.woff2) format('woff2');
src: url(/assets/fonts/Atkinson-Hyperlegible-Regular-102.woff) format('woff');
src: url(/assets/fonts/Atkinson-Hyperlegible-Regular-102a.woff2) format('woff2'),
url(/assets/fonts/Atkinson-Hyperlegible-Regular-102.woff) format('woff');
}
@font-face {
font-family: "Atkinson Hyperlegible";
font-weight: bold;
font-style: normal;
src: url(/assets/fonts/Atkinson-Hyperlegible-Bold-102a.woff2) format('woff2');
src: url(/assets/fonts/Atkinson-Hyperlegible-Bold-102.woff) format('woff');
src: url(/assets/fonts/Atkinson-Hyperlegible-Bold-102a.woff2) format('woff2'),
url(/assets/fonts/Atkinson-Hyperlegible-Bold-102.woff) format('woff');
}
@font-face {
font-family: "Atkinson Hyperlegible";
font-weight: normal;
font-style: italic;
src: url(/assets/fonts/Atkinson-Hyperlegible-Italic-102a.woff2) format('woff2');
src: url(/assets/fonts/Atkinson-Hyperlegible-Italic-102.woff) format('woff');
src: url(/assets/fonts/Atkinson-Hyperlegible-Italic-102a.woff2) format('woff2'),
url(/assets/fonts/Atkinson-Hyperlegible-Italic-102.woff) format('woff');
}
@font-face {
font-family: "Atkinson Hyperlegible";
font-weight: bold;
font-style: italic;
src: url(/assets/fonts/Atkinson-Hyperlegible-BoldItalic-102a.woff2) format('woff2');
src: url(/assets/fonts/Atkinson-Hyperlegible-BoldItalic-102.woff) format('woff');
src: url(/assets/fonts/Atkinson-Hyperlegible-BoldItalic-102a.woff2) format('woff2'),
url(/assets/fonts/Atkinson-Hyperlegible-BoldItalic-102.woff) format('woff');
}
/* Main page background */

View file

@ -40,39 +40,19 @@ function importEntriesProcessor(formData: ImportDomainPermsParams): (_entry: Dom
// Override each obfuscate entry if necessary.
if (formData.obfuscate !== undefined) {
const obfuscateEntry = (entry: DomainPerm) => {
processingFuncs.push((entry: DomainPerm) => {
entry.obfuscate = formData.obfuscate;
};
processingFuncs.push(obfuscateEntry);
});
}
// Check whether we need to append or replace
// private_comment and public_comment.
// Check whether we need to replace
// private_comment and/or public_comment.
["private_comment","public_comment"].forEach((commentType) => {
let text = formData.commentType?.trim();
if (!text) {
return;
}
switch(formData[`${commentType}_behavior`]) {
case "append":
const appendComment = (entry: DomainPerm) => {
if (entry.commentType == undefined) {
entry.commentType = text;
} else {
entry.commentType = [entry.commentType, text].join("\n");
}
};
processingFuncs.push(appendComment);
break;
case "replace":
const replaceComment = (entry: DomainPerm) => {
entry.commentType = text;
};
processingFuncs.push(replaceComment);
break;
if (formData[`replace_${commentType}`]) {
const text = formData[commentType]?.trim();
processingFuncs.push((entry: DomainPerm) => {
entry[commentType] = text;
});
}
});

View file

@ -22,6 +22,7 @@ import { gtsApi } from "../../gts-api";
import {
replaceCacheOnMutation,
removeFromCacheOnMutation,
updateCacheOnMutation,
} from "../../query-modifiers";
import { listToKeyedObject } from "../../transforms";
import type {
@ -55,6 +56,36 @@ const extended = gtsApi.injectEndpoints({
...replaceCacheOnMutation("domainAllows")
}),
updateDomainBlock: build.mutation<DomainPerm, any>({
query: ({ id, ...formData}) => ({
method: "PUT",
url: `/api/v1/admin/domain_blocks/${id}`,
asForm: true,
body: formData,
discardEmpty: false
}),
...updateCacheOnMutation("domainBlocks", {
key: (_draft, newData) => {
return newData.domain;
}
})
}),
updateDomainAllow: build.mutation<DomainPerm, any>({
query: ({ id, ...formData}) => ({
method: "PUT",
url: `/api/v1/admin/domain_allows/${id}`,
asForm: true,
body: formData,
discardEmpty: false
}),
...updateCacheOnMutation("domainAllows", {
key: (_draft, newData) => {
return newData.domain;
}
})
}),
removeDomainBlock: build.mutation<DomainPerm, string>({
query: (id) => ({
method: "DELETE",
@ -91,6 +122,16 @@ const useAddDomainBlockMutation = extended.useAddDomainBlockMutation;
*/
const useAddDomainAllowMutation = extended.useAddDomainAllowMutation;
/**
* Update a single domain permission (block) by PUTing to `/api/v1/admin/domain_blocks/{id}`.
*/
const useUpdateDomainBlockMutation = extended.useUpdateDomainBlockMutation;
/**
* Update a single domain permission (allow) by PUTing to `/api/v1/admin/domain_allows/{id}`.
*/
const useUpdateDomainAllowMutation = extended.useUpdateDomainAllowMutation;
/**
* Remove a single domain permission (block) by DELETEing to `/api/v1/admin/domain_blocks/{id}`.
*/
@ -104,6 +145,8 @@ const useRemoveDomainAllowMutation = extended.useRemoveDomainAllowMutation;
export {
useAddDomainBlockMutation,
useAddDomainAllowMutation,
useUpdateDomainBlockMutation,
useUpdateDomainAllowMutation,
useRemoveDomainBlockMutation,
useRemoveDomainAllowMutation
};

View file

@ -26,10 +26,11 @@ import type {
import type { Theme } from "../../types/theme";
import { User } from "../../types/user";
import { DefaultInteractionPolicies, UpdateDefaultInteractionPolicies } from "../../types/interaction";
import { Account } from "../../types/account";
const extended = gtsApi.injectEndpoints({
endpoints: (build) => ({
updateCredentials: build.mutation({
updateCredentials: build.mutation<Account, any>({
query: (formData) => ({
method: "PATCH",
url: `/api/v1/accounts/update_credentials`,
@ -39,6 +40,22 @@ const extended = gtsApi.injectEndpoints({
}),
...replaceCacheOnMutation("verifyCredentials")
}),
deleteHeader: build.mutation<Account, void>({
query: (_) => ({
method: "DELETE",
url: `/api/v1/profile/header`,
}),
...replaceCacheOnMutation("verifyCredentials")
}),
deleteAvatar: build.mutation<Account, void>({
query: (_) => ({
method: "DELETE",
url: `/api/v1/profile/avatar`,
}),
...replaceCacheOnMutation("verifyCredentials")
}),
user: build.query<User, void>({
query: () => ({url: `/api/v1/user`})
@ -122,6 +139,8 @@ const extended = gtsApi.injectEndpoints({
export const {
useUpdateCredentialsMutation,
useDeleteHeaderMutation,
useDeleteAvatarMutation,
useUserQuery,
usePasswordChangeMutation,
useEmailChangeMutation,

View file

@ -46,8 +46,8 @@ export interface DomainPerm {
valid?: boolean;
checked?: boolean;
commentType?: string;
private_comment_behavior?: "append" | "replace";
public_comment_behavior?: "append" | "replace";
replace_private_comment?: boolean;
replace_public_comment?: boolean;
}
/**
@ -65,8 +65,8 @@ const domainPermStripOnImport: Set<keyof DomainPerm> = new Set([
"valid",
"checked",
"commentType",
"private_comment_behavior",
"public_comment_behavior",
"replace_private_comment",
"replace_public_comment",
]);
/**

View file

@ -447,6 +447,12 @@ section.with-sidebar > form {
flex-direction: column;
justify-content: space-around;
gap: 0.5rem;
.delete-header-button,
.delete-avatar-button {
margin-top: 0.5rem;
font-size: 1rem;
}
}
/*
@ -618,6 +624,15 @@ span.form-info {
}
}
section > div.domain-block,
section > div.domain-allow {
height: 100%;
> a {
margin-top: auto;
}
}
.domain-permissions-list {
p {
margin-top: 0;
@ -976,32 +991,26 @@ button.tab-button {
.domain-perm-import-list {
.checkbox-list-wrapper {
overflow-x: auto;
display: grid;
gap: 1rem;
}
.checkbox-list {
overflow-x: auto;
.header {
align-items: center;
input[type="checkbox"] {
align-self: start;
height: 1.5rem;
}
}
.entry {
gap: 0;
width: 100%;
grid-template-columns: auto minmax(25ch, 2fr) minmax(40ch, 1fr);
grid-template-rows: auto 1fr;
input[type="checkbox"] {
margin-right: 1rem;
}
grid-template-columns: auto max(50%, 14rem) 1fr;
column-gap: 1rem;
align-items: center;
.domain-input {
margin-right: 0.5rem;
display: grid;
grid-template-columns: 1fr $fa-fw;
gap: 0.5rem;
@ -1020,13 +1029,21 @@ button.tab-button {
}
p {
align-self: center;
margin: 0;
grid-column: 4;
grid-row: 1 / span 2;
}
}
}
.set-comment-checkbox {
display: flex;
flex-direction: column;
gap: 0.25rem;
padding: 0.5rem 1rem 1rem 1rem;
width: 100%;
border: 0.1rem solid var(--gray1);
border-radius: 0.1rem;
}
}
.import-export {
@ -1406,6 +1423,7 @@ button.tab-button {
}
}
.domain-permission-details,
.domain-permission-draft-details,
.domain-permission-exclude-details,
.domain-permission-subscription-details {
@ -1414,6 +1432,7 @@ button.tab-button {
}
}
.domain-permission-details,
.domain-permission-drafts-view,
.domain-permission-draft-details,
.domain-permission-subscriptions-view,

View file

@ -32,8 +32,18 @@ import Loading from "../../../components/loading";
import BackButton from "../../../components/back-button";
import MutationButton from "../../../components/form/mutation-button";
import { useDomainAllowsQuery, useDomainBlocksQuery } from "../../../lib/query/admin/domain-permissions/get";
import { useAddDomainAllowMutation, useAddDomainBlockMutation, useRemoveDomainAllowMutation, useRemoveDomainBlockMutation } from "../../../lib/query/admin/domain-permissions/update";
import {
useDomainAllowsQuery,
useDomainBlocksQuery,
} from "../../../lib/query/admin/domain-permissions/get";
import {
useAddDomainAllowMutation,
useAddDomainBlockMutation,
useRemoveDomainAllowMutation,
useRemoveDomainBlockMutation,
useUpdateDomainAllowMutation,
useUpdateDomainBlockMutation,
} from "../../../lib/query/admin/domain-permissions/update";
import { DomainPerm } from "../../../lib/types/domain-permission";
import { NoArg } from "../../../lib/types/query";
import { Error } from "../../../components/error";
@ -41,8 +51,10 @@ import { useBaseUrl } from "../../../lib/navigation/util";
import { PermType } from "../../../lib/types/perm";
import { useCapitalize } from "../../../lib/util";
import { formDomainValidator } from "../../../lib/util/formvalidators";
import UsernameLozenge from "../../../components/username-lozenge";
import { FormSubmitEvent } from "../../../lib/form/types";
export default function DomainPermDetail() {
export default function DomainPermView() {
const baseUrl = useBaseUrl();
const search = useSearch();
@ -101,33 +113,16 @@ export default function DomainPermDetail() {
? blocks[domain]
: allows[domain];
// Render different into content depending on
// if we have a perm already for this domain.
let infoContent: React.JSX.Element;
if (existingPerm === undefined) {
infoContent = (
<span>
No stored {permType} yet, you can add one below:
</span>
);
} else {
infoContent = (
<div className="info">
<i className="fa fa-fw fa-exclamation-triangle" aria-hidden="true"></i>
<b>Editing existing domain {permTypeRaw} isn't implemented yet, <a href="https://github.com/superseriousbusiness/gotosocial/issues/1198" target="_blank" rel="noopener noreferrer">check here for progress</a></b>
</div>
);
}
const title = <span>Domain {permType} for {domain}</span>;
return (
<div>
<h1 className="text-cutoff">
<BackButton to={`~${baseUrl}/${permTypeRaw}`} />
{" "}
Domain {permType} for {domain}
</h1>
{infoContent}
<DomainPermForm
<div className="domain-permission-details">
<h1><BackButton to={`~${baseUrl}/${permTypeRaw}`} /> {title}</h1>
{ existingPerm
? <DomainPermDetails perm={existingPerm} permType={permType} />
: <span>No stored {permType} yet, you can add one below:</span>
}
<CreateOrUpdateDomainPerm
defaultDomain={domain}
perm={existingPerm}
permType={permType}
@ -136,23 +131,75 @@ export default function DomainPermDetail() {
);
}
interface DomainPermFormProps {
interface DomainPermDetailsProps {
perm: DomainPerm,
permType: PermType,
}
function DomainPermDetails({
perm,
permType
}: DomainPermDetailsProps) {
const baseUrl = useBaseUrl();
const [ location ] = useLocation();
const created = useMemo(() => {
if (perm.created_at) {
return new Date(perm.created_at).toDateString();
}
return "unknown";
}, [perm.created_at]);
return (
<dl className="info-list">
<div className="info-list-entry">
<dt>Created</dt>
<dd><time dateTime={perm.created_at}>{created}</time></dd>
</div>
<div className="info-list-entry">
<dt>Created By</dt>
<dd>
<UsernameLozenge
account={perm.created_by}
linkTo={`~/settings/moderation/accounts/${perm.created_by}`}
backLocation={`~${baseUrl}${location}`}
/>
</dd>
</div>
<div className="info-list-entry">
<dt>Domain</dt>
<dd>{perm.domain}</dd>
</div>
<div className="info-list-entry">
<dt>Permission type</dt>
<dd className={`permission-type ${permType}`}>
<i
aria-hidden={true}
className={`fa fa-${permType === "allow" ? "check" : "close"}`}
></i>
{permType}
</dd>
</div>
<div className="info-list-entry">
<dt>Subscription ID</dt>
<dd>{perm.subscription_id ?? "[none]"}</dd>
</div>
</dl>
);
}
interface CreateOrUpdateDomainPermProps {
defaultDomain: string;
perm?: DomainPerm;
permType: PermType;
}
function DomainPermForm({ defaultDomain, perm, permType }: DomainPermFormProps) {
function CreateOrUpdateDomainPerm({
defaultDomain,
perm,
permType
}: CreateOrUpdateDomainPermProps) {
const isExistingPerm = perm !== undefined;
const disabledForm = isExistingPerm
? {
disabled: true,
title: "Domain permissions currently cannot be edited."
}
: {
disabled: false,
title: "",
};
const form = {
domain: useTextInput("domain", {
@ -161,8 +208,8 @@ function DomainPermForm({ defaultDomain, perm, permType }: DomainPermFormProps)
validator: formDomainValidator,
}),
obfuscate: useBoolInput("obfuscate", { source: perm }),
commentPrivate: useTextInput("private_comment", { source: perm }),
commentPublic: useTextInput("public_comment", { source: perm })
privateComment: useTextInput("private_comment", { source: perm }),
publicComment: useTextInput("public_comment", { source: perm })
};
// Check which perm type we're meant to be handling
@ -171,112 +218,132 @@ function DomainPermForm({ defaultDomain, perm, permType }: DomainPermFormProps)
// react is like "weh" (mood), but we can decide
// which ones to use conditionally.
const [ addBlock, addBlockResult ] = useAddDomainBlockMutation();
const [ updateBlock, updateBlockResult ] = useUpdateDomainBlockMutation({ fixedCacheKey: perm?.id });
const [ removeBlock, removeBlockResult] = useRemoveDomainBlockMutation({ fixedCacheKey: perm?.id });
const [ addAllow, addAllowResult ] = useAddDomainAllowMutation();
const [ updateAllow, updateAllowResult ] = useUpdateDomainAllowMutation({ fixedCacheKey: perm?.id });
const [ removeAllow, removeAllowResult ] = useRemoveDomainAllowMutation({ fixedCacheKey: perm?.id });
const [
addTrigger,
addResult,
createOrUpdateTrigger,
createOrUpdateResult,
removeTrigger,
removeResult,
] = useMemo(() => {
return permType == "block"
? [
addBlock,
addBlockResult,
removeBlock,
removeBlockResult,
]
: [
addAllow,
addAllowResult,
removeAllow,
removeAllowResult,
];
}, [permType,
addBlock, addBlockResult, removeBlock, removeBlockResult,
addAllow, addAllowResult, removeAllow, removeAllowResult,
switch (true) {
case (permType === "block" && !isExistingPerm):
return [ addBlock, addBlockResult, removeBlock, removeBlockResult ];
case (permType === "block"):
return [ updateBlock, updateBlockResult, removeBlock, removeBlockResult ];
case !isExistingPerm:
return [ addAllow, addAllowResult, removeAllow, removeAllowResult ];
default:
return [ updateAllow, updateAllowResult, removeAllow, removeAllowResult ];
}
}, [permType, isExistingPerm,
addBlock, addBlockResult, updateBlock, updateBlockResult, removeBlock, removeBlockResult,
addAllow, addAllowResult, updateAllow, updateAllowResult, removeAllow, removeAllowResult,
]);
// Use appropriate submission params for this permType.
const [submitForm, submitFormResult] = useFormSubmit(form, [addTrigger, addResult], { changedOnly: false });
// Use appropriate submission params for this
// permType, and whether we're creating or updating.
const [submit, submitResult] = useFormSubmit(
form,
[ createOrUpdateTrigger, createOrUpdateResult ],
{
changedOnly: isExistingPerm,
// If we're updating an existing perm,
// insert the perm ID into the mutation
// data before submitting. Otherwise just
// return the mutationData unmodified.
customizeMutationArgs: (mutationData) => {
if (isExistingPerm) {
return {
id: perm?.id,
...mutationData,
};
} else {
return mutationData;
}
},
},
);
// Uppercase first letter of given permType.
const permTypeUpper = useCapitalize(permType);
const [location, setLocation] = useLocation();
function verifyUrlThenSubmit(e) {
function onSubmit(e: FormSubmitEvent) {
// Adding a new domain permissions happens on a url like
// "/settings/admin/domain-permissions/:permType/domain.com",
// but if domain input changes, that doesn't match anymore
// and causes issues later on so, before submitting the form,
// silently change url, and THEN submit.
let correctUrl = `/${permType}s/${form.domain.value}`;
if (location != correctUrl) {
setLocation(correctUrl);
if (!isExistingPerm) {
let correctUrl = `/${permType}s/${form.domain.value}`;
if (location != correctUrl) {
setLocation(correctUrl);
}
}
return submitForm(e);
return submit(e);
}
return (
<form onSubmit={verifyUrlThenSubmit}>
<TextInput
field={form.domain}
label="Domain"
placeholder="example.com"
autoCapitalize="none"
spellCheck="false"
{...disabledForm}
/>
<form onSubmit={onSubmit}>
{ !isExistingPerm &&
<TextInput
field={form.domain}
label="Domain"
placeholder="example.com"
autoCapitalize="none"
spellCheck="false"
/>
}
<Checkbox
field={form.obfuscate}
label="Obfuscate domain in public lists"
{...disabledForm}
/>
<TextArea
field={form.commentPrivate}
field={form.privateComment}
label="Private comment"
autoCapitalize="sentences"
rows={3}
{...disabledForm}
/>
<TextArea
field={form.commentPublic}
field={form.publicComment}
label="Public comment"
autoCapitalize="sentences"
rows={3}
{...disabledForm}
/>
<div className="action-buttons row">
<MutationButton
label={permTypeUpper}
result={submitFormResult}
showError={false}
{...disabledForm}
label={isExistingPerm ? "Update " + permType.toString() : permTypeUpper}
result={submitResult}
disabled={
isExistingPerm &&
!form.obfuscate.hasChanged() &&
!form.privateComment.hasChanged() &&
!form.publicComment.hasChanged()
}
/>
{
isExistingPerm &&
<MutationButton
{ isExistingPerm &&
<button
type="button"
onClick={() => removeTrigger(perm.id?? "")}
label="Remove"
result={removeResult}
className="button danger"
showError={false}
disabled={!isExistingPerm}
/>
>
Remove {permType.toString()}
</button>
}
</div>
<>
{addResult.error && <Error error={addResult.error} />}
{createOrUpdateResult.error && <Error error={createOrUpdateResult.error} />}
{removeResult.error && <Error error={removeResult.error} />}
</>

View file

@ -61,7 +61,7 @@ export default function ImportExport() {
>
&lt; back
</span>
&nbsp; Confirm import of domain {form.permType.value}s:
&nbsp; Confirm {form.permType.value}s:
</h1>
<ProcessImport
list={parseResult.data}

View file

@ -24,14 +24,12 @@ import { isValidDomainPermission, hasBetterScope } from "../../../lib/util/domai
import {
useTextInput,
useBoolInput,
useRadioInput,
useCheckListInput,
} from "../../../lib/form";
import {
Select,
TextArea,
RadioGroup,
Checkbox,
TextInput,
} from "../../../components/form/inputs";
@ -113,84 +111,81 @@ function ImportList({ list, data: domainPerms, permType }: ImportListProps) {
privateComment: useTextInput("private_comment", {
defaultValue: `Imported on ${new Date().toLocaleString()}`
}),
privateCommentBehavior: useRadioInput("private_comment_behavior", {
defaultValue: "append",
options: {
append: "Append to",
replace: "Replace"
}
}),
replacePrivateComment: useBoolInput("replace_private_comment", { defaultValue: false }),
publicComment: useTextInput("public_comment"),
publicCommentBehavior: useRadioInput("public_comment_behavior", {
defaultValue: "append",
options: {
append: "Append to",
replace: "Replace"
}
}),
replacePublicComment: useBoolInput("replace_public_comment", { defaultValue: false }),
permType: permType,
};
const [importDomains, importResult] = useFormSubmit(form, useImportDomainPermsMutation(), { changedOnly: false });
const [importDomains, importResult] = useFormSubmit(
form,
useImportDomainPermsMutation(),
{ changedOnly: false },
);
return (
<>
<form
onSubmit={importDomains}
className="domain-perm-import-list"
>
<span>{list.length} domain{list.length != 1 ? "s" : ""} in this list</span>
<form
onSubmit={importDomains}
className="domain-perm-import-list"
>
<span>{list.length} domain{list.length != 1 ? "s" : ""} in this list</span>
{hasComment.both &&
{hasComment.both &&
<Select field={showComment} options={
<>
<option value="public_comment">Show public comments</option>
<option value="private_comment">Show private comments</option>
</>
} />
}
}
<div className="checkbox-list-wrapper">
<DomainCheckList
field={form.domains}
domainPerms={domainPerms}
commentType={showComment.value as "public_comment" | "private_comment"}
permType={form.permType}
/>
</div>
<div className="checkbox-list-wrapper">
<DomainCheckList
field={form.domains}
domainPerms={domainPerms}
commentType={showComment.value as "public_comment" | "private_comment"}
permType={form.permType}
/>
</div>
<Checkbox
field={form.obfuscate}
label="Obfuscate domains in public lists"
/>
<div className="set-comment-checkbox">
<Checkbox
field={form.replacePrivateComment}
label="Set/replace private comment(s) to:"
/>
<TextArea
field={form.privateComment}
label="Private comment"
rows={3}
disabled={!form.replacePrivateComment.value}
placeholder="Private comment"
/>
<RadioGroup
field={form.privateCommentBehavior}
label="imported private comment"
/>
</div>
<div className="set-comment-checkbox">
<Checkbox
field={form.replacePublicComment}
label="Set/replace public comment(s) to:"
/>
<TextArea
field={form.publicComment}
label="Public comment"
rows={3}
disabled={!form.replacePublicComment.value}
placeholder="Public comment"
/>
<RadioGroup
field={form.publicCommentBehavior}
label="imported public comment"
/>
</div>
<Checkbox
field={form.obfuscate}
label="Obfuscate domains in public lists"
/>
<MutationButton
label="Import"
disabled={false}
result={importResult}
/>
</form>
</>
<MutationButton
label="Import"
disabled={false}
result={importResult}
/>
</form>
);
}

View file

@ -25,7 +25,7 @@ import ReportDetail from "./reports/detail";
import { ErrorBoundary } from "../../lib/navigation/error";
import ImportExport from "./domain-permissions/import-export";
import DomainPermissionsOverview from "./domain-permissions/overview";
import DomainPermDetail from "./domain-permissions/detail";
import DomainPermView from "./domain-permissions/detail";
import AccountsSearch from "./accounts";
import AccountsPending from "./accounts/pending";
import AccountDetail from "./accounts/detail";
@ -160,7 +160,7 @@ function ModerationDomainPermsRouter() {
<Route path="/subscriptions/preview" component={DomainPermissionSubscriptionsPreview} />
<Route path="/subscriptions/:permSubId" component={DomainPermissionSubscriptionDetail} />
<Route path="/:permType" component={DomainPermissionsOverview} />
<Route path="/:permType/:domain" component={DomainPermDetail} />
<Route path="/:permType/:domain" component={DomainPermView} />
<Route><Redirect to="/blocks"/></Route>
</Switch>
</ErrorBoundary>

View file

@ -17,7 +17,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
import React, { useMemo } from "react";
import React, { useMemo, useState } from "react";
import {
useTextInput,
@ -41,7 +41,7 @@ import FormWithData from "../../lib/form/form-with-data";
import FakeProfile from "../../components/profile";
import MutationButton from "../../components/form/mutation-button";
import { useAccountThemesQuery } from "../../lib/query/user";
import { useAccountThemesQuery, useDeleteAvatarMutation, useDeleteHeaderMutation } from "../../lib/query/user";
import { useUpdateCredentialsMutation } from "../../lib/query/user";
import { useVerifyCredentialsQuery } from "../../lib/query/login";
import { useInstanceV1Query } from "../../lib/query/gts-api";
@ -116,17 +116,23 @@ function UserProfileForm({ data: profile }: UserProfileFormProps) {
theme: useTextInput("theme", { source: profile }),
};
const [ noHeader, setNoHeader ] = useState(!profile.header_media_id);
const [ deleteHeader, deleteHeaderRes ] = useDeleteHeaderMutation();
const [ noAvatar, setNoAvatar ] = useState(!profile.avatar_media_id);
const [ deleteAvatar, deleteAvatarRes ] = useDeleteAvatarMutation();
const [submitForm, result] = useFormSubmit(form, useUpdateCredentialsMutation(), {
changedOnly: true,
onFinish: () => {
form.avatar.reset();
form.header.reset();
onFinish: (res) => {
if ('data' in res) {
form.avatar.reset();
form.header.reset();
setNoAvatar(!res.data.avatar_media_id);
setNoHeader(!res.data.header_media_id);
}
}
});
const noAvatarSet = !profile.avatar_media_id;
const noHeaderSet = !profile.header_media_id;
return (
<form className="user-profile" onSubmit={submitForm}>
<h1>Profile</h1>
@ -152,7 +158,21 @@ function UserProfileForm({ data: profile }: UserProfileFormProps) {
label="Image description; only settable if not using default header"
placeholder="A green field with pink flowers."
autoCapitalize="sentences"
disabled={noHeaderSet && !form.header.value}
disabled={noHeader && !form.header.value}
/>
<MutationButton
className="delete-header-button"
label="Delete header"
disabled={noHeader}
result={deleteHeaderRes}
onClick={(e) => {
e.preventDefault();
deleteHeader().then(res => {
if ('data' in res) {
setNoHeader(true);
}
});
}}
/>
</fieldset>
@ -168,7 +188,21 @@ function UserProfileForm({ data: profile }: UserProfileFormProps) {
label="Image description; only settable if not using default avatar"
placeholder="A cute drawing of a smiling sloth."
autoCapitalize="sentences"
disabled={noAvatarSet && !form.avatar.value}
disabled={noAvatar && !form.avatar.value}
/>
<MutationButton
className="delete-avatar-button"
label="Delete avatar"
disabled={noAvatar}
result={deleteAvatarRes}
onClick={(e) => {
e.preventDefault();
deleteAvatar().then(res => {
if ('data' in res) {
setNoAvatar(true);
}
});
}}
/>
</fieldset>