...

Source file src/codeberg.org/tslocum/sriracha/internal/database/database.go

Documentation: codeberg.org/tslocum/sriracha/internal/database

     1  // Package database provides methods for interacting with a Sriracha database.
     2  package database
     3  
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"log"
	"net/url"
	"path/filepath"
	"strconv"
	"strings"

	"codeberg.org/tslocum/sriracha"
	. "codeberg.org/tslocum/sriracha/util"
	"github.com/alexedwards/argon2id"
	"github.com/gabriel-vasile/mimetype"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
)
    22  
    23  var argon2idParameters = &argon2id.Params{
    24  	Memory:      128 * 1024,
    25  	Iterations:  2,
    26  	Parallelism: 2,
    27  	SaltLength:  16,
    28  	KeyLength:   64,
    29  }
    30  
// DB represents a database connection.
//
// A DB returned by Begin wraps one connection acquired from the pool with an
// explicit transaction open on it. A DB whose conn is nil acts as a mock:
// methods that check for a nil connection return zero values or do nothing.
type DB struct {
	// conn is the pooled connection the transaction runs on; nil for a mock DB.
	conn      *pgxpool.Conn
	// Plugin, when non-empty, namespaces config keys as "<Plugin>.<key>"
	// (see configKey).
	Plugin    string
	// config is the application configuration this DB was created with.
	config    *Config
	// committed is set once Commit or CommitWithErr has committed the
	// transaction and released conn; later commits then return early.
	committed bool
}
    38  
    39  func Connect(c *Config) (*pgxpool.Pool, error) {
    40  	url := c.DBURL
    41  	if strings.TrimSpace(url) == "" {
    42  		url = fmt.Sprintf("postgres://%s:%s@%s/%s", c.Username, c.Password, c.Address, c.DBName)
    43  	}
    44  
    45  	config, err := pgxpool.ParseConfig(url)
    46  	if err != nil {
    47  		return nil, fmt.Errorf("failed to parse database configuration: %s", err)
    48  	}
    49  	config.MinConns = 1
    50  	config.MinIdleConns = 1
    51  	config.MaxConns = 1
    52  
    53  	pool, err := pgxpool.NewWithConfig(context.Background(), config)
    54  	if err != nil {
    55  		return nil, fmt.Errorf("failed to connect to database: %s", err)
    56  	}
    57  
    58  	conn, err := pool.Acquire(context.Background())
    59  	if err != nil {
    60  		return nil, fmt.Errorf("failed to acquire conn: %s", err)
    61  	}
    62  	defer conn.Release()
    63  
    64  	_, err = conn.Exec(context.Background(), "BEGIN")
    65  	if err != nil {
    66  		return nil, fmt.Errorf("failed to begin transaction: %s", err)
    67  	}
    68  
    69  	db := &DB{
    70  		conn:   conn,
    71  		config: c,
    72  	}
    73  	err = db.initialize()
    74  	if err != nil {
    75  		return nil, fmt.Errorf("failed to initialize database: %s", err)
    76  	}
    77  
    78  	err = db.upgrade(c.Root)
    79  	if err != nil {
    80  		return nil, fmt.Errorf("failed to upgrade database: %s", err)
    81  	}
    82  
    83  	db.createSuperAdminAccount(c.SaltPass)
    84  
    85  	_, err = conn.Exec(context.Background(), "COMMIT")
    86  	if err != nil {
    87  		return nil, fmt.Errorf("failed to commit transaction: %s", err)
    88  	}
    89  	return pool, nil
    90  }
    91  
    92  func (db *DB) initialize() error {
    93  	_, err := db.conn.Exec(context.Background(), "SELECT 1=1")
    94  	if err != nil {
    95  		return fmt.Errorf("failed to test database connection: %s", err)
    96  	}
    97  
    98  	var tablecount int
    99  	err = db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'account'").Scan(&tablecount)
   100  	if err != nil {
   101  		return fmt.Errorf("failed to select whether account table exists: %s", err)
   102  	} else if tablecount > 0 {
   103  		return nil
   104  	}
   105  
   106  	fmt.Printf("Initializing database version 1...\n")
   107  	_, err = db.conn.Exec(context.Background(), dbSchema[0])
   108  	if err != nil {
   109  		return fmt.Errorf("failed to create database: %s", err)
   110  	}
   111  	fmt.Printf("Database initialized.\n")
   112  	return nil
   113  }
   114  
   115  func (db *DB) _upgrade(rootDir string, v int) error {
   116  	_, err := db.conn.Exec(context.Background(), dbSchema[v-1])
   117  	if err != nil {
   118  		return err
   119  	}
   120  	switch v {
   121  	case 5: // Add file MIME type to posts.
   122  		boards := db.AllBoards()
   123  		for _, b := range boards {
   124  			allThreads := db.AllThreads(b, false)
   125  			for _, threadInfo := range allThreads {
   126  				posts := db.AllPostsInThread(threadInfo[0], false)
   127  				for _, post := range posts {
   128  					if post.File != "" && !post.IsEmbed() {
   129  						if strings.HasSuffix(post.File, ".tgkr") {
   130  							post.FileMIME = "application/x-tegaki"
   131  						} else {
   132  							mimeInfo, err := mimetype.DetectFile(filepath.Join(rootDir, b.Dir, "src", post.File))
   133  							if err == nil {
   134  								post.FileMIME = mimeInfo.String()
   135  							}
   136  						}
   137  						if post.FileMIME != "" {
   138  							_, err = db.conn.Exec(context.Background(), "UPDATE post SET filemime = $1 WHERE id = $2", post.FileMIME, post.ID)
   139  							if err != nil {
   140  								return err
   141  							}
   142  						}
   143  					}
   144  				}
   145  			}
   146  		}
   147  	}
   148  	return nil
   149  }
   150  
   151  func (db *DB) upgrade(rootDir string) error {
   152  	var versionString string
   153  	err := db.conn.QueryRow(context.Background(), "SELECT value FROM config WHERE name = 'version'").Scan(&versionString)
   154  	if err != nil {
   155  		return fmt.Errorf("failed to select database version: %s", err)
   156  	}
   157  	version, err := strconv.Atoi(versionString)
   158  	if err != nil {
   159  		return fmt.Errorf("failed to parse database version: %s", err)
   160  	}
   161  	maxVersion := len(dbSchema)
   162  	if version == maxVersion {
   163  		return nil
   164  	} else if version > maxVersion {
   165  		return fmt.Errorf("database version %d is newer than application version %d", version, maxVersion)
   166  	}
   167  	fmt.Printf("Upgrading database from version %d to %d...\n", version, maxVersion)
   168  	for v := version + 1; v <= maxVersion; v++ {
   169  		err = db._upgrade(rootDir, v)
   170  		if err != nil {
   171  			return fmt.Errorf("failed to upgrade database from version %d to version %d: %s", v-1, v, err)
   172  		}
   173  	}
   174  	fmt.Printf("Database upgraded.\n")
   175  	return nil
   176  }
   177  
   178  func Begin(pool *pgxpool.Pool, config *Config) *DB {
   179  	if pool == nil {
   180  		// Return mock database.
   181  		return &DB{
   182  			config: config,
   183  		}
   184  	}
   185  
   186  	conn, err := pool.Acquire(context.Background())
   187  	if err != nil {
   188  		log.Fatalf("failed to acquire connection from pool: %s", err)
   189  	}
   190  
   191  	_, err = conn.Exec(context.Background(), "BEGIN")
   192  	if err != nil {
   193  		conn.Release()
   194  		log.Fatalf("failed to begin transaction: %s", err)
   195  	}
   196  
   197  	return &DB{
   198  		conn:   conn,
   199  		config: config,
   200  	}
   201  }
   202  
   203  func (db *DB) SoftRollBack() {
   204  	if db.conn == nil {
   205  		return
   206  	}
   207  	_, err := db.conn.Exec(context.Background(), "ROLLBACK")
   208  	if err != nil {
   209  		log.Fatalf("failed to rollback transaction: %s", err)
   210  	}
   211  }
   212  
   213  func (db *DB) RollBack() {
   214  	if db.conn == nil {
   215  		return
   216  	}
   217  	_, err := db.conn.Exec(context.Background(), "ROLLBACK")
   218  	if err != nil {
   219  		log.Fatalf("failed to rollback transaction: %s", err)
   220  	}
   221  	db.conn.Release()
   222  }
   223  
   224  func (db *DB) Commit() {
   225  	if db.conn == nil || db.committed {
   226  		return
   227  	}
   228  	_, err := db.conn.Exec(context.Background(), "COMMIT")
   229  	if err != nil {
   230  		log.Fatalf("failed to commit transaction: %s", err)
   231  	}
   232  	db.conn.Release()
   233  	db.committed = true
   234  }
   235  
   236  func (db *DB) CommitWithErr() error {
   237  	if db.conn == nil || db.committed {
   238  		return nil
   239  	}
   240  	_, err := db.conn.Exec(context.Background(), "COMMIT")
   241  	if err != nil {
   242  		return fmt.Errorf("failed to commit transaction: %s", err)
   243  	}
   244  	db.conn.Release()
   245  	db.committed = true
   246  	return nil
   247  }
   248  
   249  func (db *DB) configKey(key string) string {
   250  	key = strings.ToLower(key)
   251  	if len(db.Plugin) != 0 {
   252  		return db.Plugin + "." + key
   253  	}
   254  	return key
   255  }
   256  
   257  func (db *DB) HaveConfig(key string) bool {
   258  	if db.conn == nil {
   259  		return false
   260  	}
   261  	key = db.configKey(key)
   262  	var count int
   263  	err := db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM config WHERE name = $1", key).Scan(&count)
   264  	if err == pgx.ErrNoRows {
   265  		return false
   266  	} else if err != nil {
   267  		log.Fatalf("failed to select config count %s: %s", key, err)
   268  	}
   269  	return count > 0
   270  }
   271  
   272  func (db *DB) GetString(key string) string {
   273  	if db.conn == nil {
   274  		return ""
   275  	}
   276  	key = db.configKey(key)
   277  	var value string
   278  	err := db.conn.QueryRow(context.Background(), "SELECT value FROM config WHERE name = $1", key).Scan(&value)
   279  	if err == pgx.ErrNoRows {
   280  		return ""
   281  	} else if err != nil {
   282  		log.Fatalf("failed to get string %s: %s", key, err)
   283  	}
   284  	return value
   285  }
   286  
   287  func (db *DB) SaveString(key string, value string) {
   288  	if db.conn == nil {
   289  		return
   290  	}
   291  	value = strings.ReplaceAll(value, "\r", "")
   292  	_, err := db.conn.Exec(context.Background(), "INSERT INTO config VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET value = $3", db.configKey(key), value, value)
   293  	if err != nil {
   294  		log.Fatalf("failed to save string: %s", err)
   295  	}
   296  }
   297  
   298  func (db *DB) GetMultiString(key string) []string {
   299  	return strings.Split(db.GetString(key), "|||")
   300  }
   301  
   302  func (db *DB) GetBool(key string) bool {
   303  	return db.GetString(key) == "1"
   304  }
   305  
   306  func (db *DB) SaveBool(key string, value bool) {
   307  	v := "0"
   308  	if value {
   309  		v = "1"
   310  	}
   311  	db.SaveString(key, v)
   312  }
   313  
   314  func (db *DB) SaveMultiString(key string, value []string) {
   315  	db.SaveString(key, strings.Join(value, "|||"))
   316  }
   317  
   318  func (db *DB) GetInt(key string) int {
   319  	return ParseInt(db.GetString(key))
   320  }
   321  
   322  func (db *DB) SaveInt(key string, value int) {
   323  	db.SaveString(key, strconv.Itoa(value))
   324  }
   325  
   326  func (db *DB) GetInt64(key string) int64 {
   327  	return ParseInt64(db.GetString(key))
   328  }
   329  
   330  func (db *DB) SaveInt64(key string, value int64) {
   331  	db.SaveString(key, fmt.Sprintf("%d", value))
   332  }
   333  
   334  func (db *DB) GetMultiInt(key string) []int {
   335  	s := db.GetString(key)
   336  	if s == "" {
   337  		return nil
   338  	}
   339  	var values []int
   340  	for _, v := range strings.Split(s, "|||") {
   341  		values = append(values, ParseInt(v))
   342  	}
   343  	return values
   344  }
   345  
   346  func (db *DB) SaveMultiInt(key string, values []int) {
   347  	var out string
   348  	for i, v := range values {
   349  		if i != 0 {
   350  			out += "|||"
   351  		}
   352  		out += strconv.Itoa(v)
   353  	}
   354  	db.SaveString(key, out)
   355  }
   356  
   357  func (db *DB) GetFloat(key string) float64 {
   358  	return ParseFloat(db.GetString(key))
   359  }
   360  
   361  func (db *DB) SaveFloat(key string, value float64) {
   362  	db.SaveString(key, fmt.Sprintf("%f", value))
   363  }
   364  
   365  func (db *DB) newSessionKey() string {
   366  	const keyLength = 48
   367  	buf := make([]byte, keyLength)
   368  	for {
   369  		_, err := rand.Read(buf)
   370  		if err != nil {
   371  			panic(err)
   372  		}
   373  		sessionKey := base64.URLEncoding.EncodeToString(buf)
   374  
   375  		var numAccounts int
   376  		err = db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM account WHERE session = $1", sessionKey).Scan(&numAccounts)
   377  		if err != nil {
   378  			log.Fatalf("failed to select number of accounts with session key: %s", err)
   379  		} else if numAccounts == 0 {
   380  			return sessionKey
   381  		}
   382  	}
   383  }
   384  
   385  func (db *DB) Exec(sql string, arguments ...any) (pgconn.CommandTag, error) {
   386  	return db.conn.Exec(context.Background(), sql, arguments...)
   387  }
   388  
   389  func (db *DB) QueryRow(sql string, arguments ...any) pgx.Row {
   390  	return db.conn.QueryRow(context.Background(), sql, arguments...)
   391  }
   392  
   393  // Validate database interface during compilation.
   394  var (
   395  	_ sriracha.DB = &DB{}
   396  )
   397  

View as plain text