/*
   GoToSocial
   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU Affero General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU Affero General Public License for more details.

   You should have received a copy of the GNU Affero General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

package bundb

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"runtime"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/stdlib"
	"github.com/superseriousbusiness/gotosocial/internal/config"
	"github.com/superseriousbusiness/gotosocial/internal/db"
	"github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations"
	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
	"github.com/superseriousbusiness/gotosocial/internal/id"
	"github.com/superseriousbusiness/gotosocial/internal/log"
	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect/pgdialect"
	"github.com/uptrace/bun/dialect/sqlitedialect"
	"github.com/uptrace/bun/migrate"

	"modernc.org/sqlite"
)

const (
	dbTypePostgres = "postgres"
	dbTypeSqlite   = "sqlite"

	// dbTLSModeDisable does not attempt to make a TLS connection to the database.
	dbTLSModeDisable = "disable"
	// dbTLSModeEnable attempts to make a TLS connection to the database, but doesn't fail if
	// the certificate passed by the database isn't verified.
	dbTLSModeEnable = "enable"
	// dbTLSModeRequire attempts to make a TLS connection to the database, and requires
	// that the certificate presented by the database is valid.
	dbTLSModeRequire = "require"
	// dbTLSModeUnset means that the TLS mode has not been set.
	dbTLSModeUnset = ""
)

var registerTables = []interface{}{
	&gtsmodel.AccountToEmoji{},
	&gtsmodel.StatusToEmoji{},
	&gtsmodel.StatusToTag{},
}

// DBService satisfies the DB interface
type DBService struct {
	db.Account
	db.Admin
	db.Basic
	db.Domain
	db.Emoji
	db.Instance
	db.Media
	db.Mention
	db.Notification
	db.Relationship
	db.Session
	db.Status
	db.Timeline
	db.User
	db.Tombstone
	conn *DBConn
}

// GetConn returns the underlying bun connection.
// Should only be used in testing + exceptional circumstance.
func (dbService *DBService) GetConn() *DBConn {
	return dbService.conn
}

func doMigration(ctx context.Context, db *bun.DB) error {
	migrator := migrate.NewMigrator(db, migrations.Migrations)

	if err := migrator.Init(ctx); err != nil {
		return err
	}

	group, err := migrator.Migrate(ctx)
	if err != nil {
		if err.Error() == "migrate: there are no any migrations" {
			return nil
		}
		return err
	}

	if group.ID == 0 {
		log.Info("there are no new migrations to run")
		return nil
	}

	log.Infof("MIGRATED DATABASE TO %s", group)
	return nil
}

// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context) (db.DB, error) {
	var conn *DBConn
	var err error
	dbType := strings.ToLower(config.GetDbType())

	switch dbType {
	case dbTypePostgres:
		conn, err = pgConn(ctx)
		if err != nil {
			return nil, err
		}
	case dbTypeSqlite:
		conn, err = sqliteConn(ctx)
		if err != nil {
			return nil, err
		}
	default:
		return nil, fmt.Errorf("database type %s not supported for bundb", dbType)
	}

	// Add database query hook
	conn.DB.AddQueryHook(queryHook{})

	// table registration is needed for many-to-many, see:
	// https://bun.uptrace.dev/orm/many-to-many-relation/
	for _, t := range registerTables {
		conn.RegisterModel(t)
	}

	// perform any pending database migrations: this includes
	// the very first 'migration' on startup which just creates
	// necessary tables
	if err := doMigration(ctx, conn.DB); err != nil {
		return nil, fmt.Errorf("db migration error: %s", err)
	}

	// Create DB structs that require ptrs to each other
	account := &accountDB{conn: conn}
	admin := &adminDB{conn: conn}
	domain := &domainDB{conn: conn}
	mention := &mentionDB{conn: conn}
	notif := &notificationDB{conn: conn}
	status := &statusDB{conn: conn}
	emoji := &emojiDB{conn: conn}
	timeline := &timelineDB{conn: conn}
	tombstone := &tombstoneDB{conn: conn}
	user := &userDB{conn: conn}

	// Setup DB cross-referencing
	account.emojis = emoji
	account.status = status
	admin.users = user
	status.accounts = account
	status.emojis = emoji
	status.mentions = mention
	timeline.status = status

	// Initialize db structs
	account.init()
	domain.init()
	emoji.init()
	mention.init()
	notif.init()
	status.init()
	tombstone.init()
	user.init()

	ps := &DBService{
		Account: account,
		Admin: &adminDB{
			conn:     conn,
			accounts: account,
			users:    user,
		},
		Basic: &basicDB{
			conn: conn,
		},
		Domain: domain,
		Emoji:  emoji,
		Instance: &instanceDB{
			conn: conn,
		},
		Media: &mediaDB{
			conn: conn,
		},
		Mention:      mention,
		Notification: notif,
		Relationship: &relationshipDB{
			conn: conn,
		},
		Session: &sessionDB{
			conn: conn,
		},
		Status:    status,
		Timeline:  timeline,
		User:      user,
		Tombstone: tombstone,
		conn:      conn,
	}

	// we can confidently return this useable service now
	return ps, nil
}

func sqliteConn(ctx context.Context) (*DBConn, error) {
	// validate db address has actually been set
	dbAddress := config.GetDbAddress()
	if dbAddress == "" {
		return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.DbAddressFlag())
	}

	// Drop anything fancy from DB address
	dbAddress = strings.Split(dbAddress, "?")[0]
	dbAddress = strings.TrimPrefix(dbAddress, "file:")

	// Append our own SQLite preferences
	dbAddress = "file:" + dbAddress + "?cache=shared"

	var inMem bool

	if dbAddress == "file::memory:?cache=shared" {
		dbAddress = fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid.NewString())
		log.Infof("using in-memory database address " + dbAddress)
		log.Warn("sqlite in-memory database should only be used for debugging")
		inMem = true
	}

	// Open new DB instance
	sqldb, err := sql.Open("sqlite", dbAddress)
	if err != nil {
		if errWithCode, ok := err.(*sqlite.Error); ok {
			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
		}
		return nil, fmt.Errorf("could not open sqlite db: %s", err)
	}

	tweakConnectionValues(sqldb)

	if inMem {
		// don't close connections on disconnect -- otherwise
		// the SQLite database will be deleted when there
		// are no active connections
		sqldb.SetConnMaxLifetime(0)
	}

	conn := WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()))

	// ping to check the db is there and listening
	if err := conn.PingContext(ctx); err != nil {
		if errWithCode, ok := err.(*sqlite.Error); ok {
			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
		}
		return nil, fmt.Errorf("sqlite ping: %s", err)
	}

	log.Info("connected to SQLITE database")
	return conn, nil
}

func pgConn(ctx context.Context) (*DBConn, error) {
	opts, err := deriveBunDBPGOptions() //nolint:contextcheck
	if err != nil {
		return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
	}

	sqldb := stdlib.OpenDB(*opts)

	tweakConnectionValues(sqldb)

	conn := WrapDBConn(bun.NewDB(sqldb, pgdialect.New()))

	// ping to check the db is there and listening
	if err := conn.PingContext(ctx); err != nil {
		return nil, fmt.Errorf("postgres ping: %s", err)
	}

	log.Info("connected to POSTGRES database")
	return conn, nil
}

/*
	HANDY STUFF
*/

// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
	if strings.ToUpper(config.GetDbType()) != db.DBTypePostgres {
		return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, config.DbTypeFlag())
	}

	// these are all optional, the db adapter figures out defaults
	address := config.GetDbAddress()

	// validate database
	database := config.GetDbDatabase()
	if database == "" {
		return nil, errors.New("no database set")
	}

	var tlsConfig *tls.Config
	switch config.GetDbTLSMode() {
	case dbTLSModeDisable, dbTLSModeUnset:
		break // nothing to do
	case dbTLSModeEnable:
		/* #nosec G402 */
		tlsConfig = &tls.Config{
			InsecureSkipVerify: true,
		}
	case dbTLSModeRequire:
		tlsConfig = &tls.Config{
			InsecureSkipVerify: false,
			ServerName:         address,
			MinVersion:         tls.VersionTLS12,
		}
	}

	if certPath := config.GetDbTLSCACert(); tlsConfig != nil && certPath != "" {
		// load the system cert pool first -- we'll append the given CA cert to this
		certPool, err := x509.SystemCertPool()
		if err != nil {
			return nil, fmt.Errorf("error fetching system CA cert pool: %s", err)
		}

		// open the file itself and make sure there's something in it
		caCertBytes, err := os.ReadFile(certPath)
		if err != nil {
			return nil, fmt.Errorf("error opening CA certificate at %s: %s", certPath, err)
		}
		if len(caCertBytes) == 0 {
			return nil, fmt.Errorf("ca cert at %s was empty", certPath)
		}

		// make sure we have a PEM block
		caPem, _ := pem.Decode(caCertBytes)
		if caPem == nil {
			return nil, fmt.Errorf("could not parse cert at %s into PEM", certPath)
		}

		// parse the PEM block into the certificate
		caCert, err := x509.ParseCertificate(caPem.Bytes)
		if err != nil {
			return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err)
		}

		// we're happy, add it to the existing pool and then use this pool in our tls config
		certPool.AddCert(caCert)
		tlsConfig.RootCAs = certPool
	}

	cfg, _ := pgx.ParseConfig("")
	if address != "" {
		cfg.Host = address
	}
	if port := config.GetDbPort(); port > 0 {
		cfg.Port = uint16(port)
	}
	if u := config.GetDbUser(); u != "" {
		cfg.User = u
	}
	if p := config.GetDbPassword(); p != "" {
		cfg.Password = p
	}
	if tlsConfig != nil {
		cfg.TLSConfig = tlsConfig
	}
	cfg.Database = database
	cfg.PreferSimpleProtocol = true
	cfg.RuntimeParams["application_name"] = config.GetApplicationName()

	return cfg, nil
}

// https://bun.uptrace.dev/postgres/running-bun-in-production.html#database-sql
func tweakConnectionValues(sqldb *sql.DB) {
	maxOpenConns := 4 * runtime.GOMAXPROCS(0)
	sqldb.SetMaxOpenConns(maxOpenConns)
	sqldb.SetMaxIdleConns(maxOpenConns)
}

/*
	CONVERSION FUNCTIONS
*/

func (dbService *DBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string) ([]*gtsmodel.Tag, error) {
	protocol := config.GetProtocol()
	host := config.GetHost()

	newTags := []*gtsmodel.Tag{}
	for _, t := range tags {
		tag := &gtsmodel.Tag{}
		// we can use selectorinsert here to create the new tag if it doesn't exist already
		// inserted will be true if this is a new tag we just created
		if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
			if err == sql.ErrNoRows {
				// tag doesn't exist yet so populate it
				newID, err := id.NewRandomULID()
				if err != nil {
					return nil, err
				}
				tag.ID = newID
				tag.URL = fmt.Sprintf("%s://%s/tags/%s", protocol, host, t)
				tag.Name = t
				tag.FirstSeenFromAccountID = originAccountID
				tag.CreatedAt = time.Now()
				tag.UpdatedAt = time.Now()
				useable := true
				tag.Useable = &useable
				listable := true
				tag.Listable = &listable
			} else {
				return nil, fmt.Errorf("error getting tag with name %s: %s", t, err)
			}
		}

		// bail already if the tag isn't useable
		if !*tag.Useable {
			continue
		}
		tag.LastStatusAt = time.Now()
		newTags = append(newTags, tag)
	}
	return newTags, nil
}