...

Source file src/codeberg.org/tslocum/sriracha/internal/database/database.go

Documentation: codeberg.org/tslocum/sriracha/internal/database

     1  // Package database provides methods for interacting with a Sriracha database.
     2  package database
     3  
     4  import (
     5  	"context"
     6  	"crypto/rand"
     7  	"encoding/base64"
     8  	"fmt"
     9  	"log"
    10  	"path/filepath"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"codeberg.org/tslocum/sriracha"
    15  	. "codeberg.org/tslocum/sriracha/util"
    16  	"github.com/alexedwards/argon2id"
    17  	"github.com/gabriel-vasile/mimetype"
    18  	"github.com/jackc/pgx/v5"
    19  	"github.com/jackc/pgx/v5/pgconn"
    20  	"github.com/jackc/pgx/v5/pgxpool"
    21  )
    22  
    23  var argon2idParameters = &argon2id.Params{
    24  	Memory:      128 * 1024,
    25  	Iterations:  2,
    26  	Parallelism: 2,
    27  	SaltLength:  16,
    28  	KeyLength:   64,
    29  }
    30  
// DB represents a database connection.
//
// A DB returned by Begin wraps one connection acquired from a pgxpool.Pool
// on which a transaction has been started with BEGIN. When Begin is given a
// nil pool, the DB has no connection and acts as a mock: the config getters
// return zero values and the save methods do nothing.
type DB struct {
	// conn holds the open transaction; nil for a mock DB.
	conn *pgxpool.Conn
	// Plugin, when non-empty, namespaces config keys as "<Plugin>.<key>"
	// (see configKey).
	Plugin string
	// config is the application configuration passed to Connect/Begin.
	config *Config
	// committed is set once the transaction has been finished (committed or
	// rolled back) and conn has been released back to the pool.
	committed bool
}
    38  
    39  func Connect(c *Config) (*pgxpool.Pool, error) {
    40  	url := c.DBURL
    41  	if strings.TrimSpace(url) == "" {
    42  		url = fmt.Sprintf("postgres://%s:%s@%s/%s", c.Username, c.Password, c.Address, c.DBName)
    43  	}
    44  
    45  	config, err := pgxpool.ParseConfig(url)
    46  	if err != nil {
    47  		return nil, fmt.Errorf("failed to parse database configuration: %s", err)
    48  	}
    49  	config.MinConns = 1
    50  	config.MinIdleConns = 1
    51  	config.MaxConns = 1
    52  
    53  	pool, err := pgxpool.NewWithConfig(context.Background(), config)
    54  	if err != nil {
    55  		return nil, fmt.Errorf("failed to connect to database: %s", err)
    56  	}
    57  
    58  	conn, err := pool.Acquire(context.Background())
    59  	if err != nil {
    60  		return nil, fmt.Errorf("failed to acquire conn: %s", err)
    61  	}
    62  	defer conn.Release()
    63  
    64  	_, err = conn.Exec(context.Background(), "BEGIN")
    65  	if err != nil {
    66  		return nil, fmt.Errorf("failed to begin transaction: %s", err)
    67  	}
    68  
    69  	db := &DB{
    70  		conn:   conn,
    71  		config: c,
    72  	}
    73  	err = db.initialize()
    74  	if err != nil {
    75  		return nil, fmt.Errorf("failed to initialize database: %s", err)
    76  	}
    77  
    78  	err = db.upgrade(c.Root)
    79  	if err != nil {
    80  		return nil, fmt.Errorf("failed to upgrade database: %s", err)
    81  	}
    82  
    83  	db.createSuperAdminAccount(c.SaltPass)
    84  
    85  	_, err = conn.Exec(context.Background(), "COMMIT")
    86  	if err != nil {
    87  		return nil, fmt.Errorf("failed to commit transaction: %s", err)
    88  	}
    89  	return pool, nil
    90  }
    91  
    92  func (db *DB) initialize() error {
    93  	_, err := db.conn.Exec(context.Background(), "SELECT 1=1")
    94  	if err != nil {
    95  		return fmt.Errorf("failed to test database connection: %s", err)
    96  	}
    97  
    98  	var tablecount int
    99  	err = db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'account'").Scan(&tablecount)
   100  	if err != nil {
   101  		return fmt.Errorf("failed to select whether account table exists: %s", err)
   102  	} else if tablecount > 0 {
   103  		return nil
   104  	}
   105  
   106  	fmt.Printf("Initializing database version 1...\n")
   107  	_, err = db.conn.Exec(context.Background(), dbSchema[0])
   108  	if err != nil {
   109  		return fmt.Errorf("failed to create database: %s", err)
   110  	}
   111  	fmt.Printf("Database initialized.\n")
   112  	return nil
   113  }
   114  
   115  func (db *DB) _upgrade(rootDir string, v int) error {
   116  	_, err := db.conn.Exec(context.Background(), dbSchema[v-1])
   117  	if err != nil {
   118  		return err
   119  	}
   120  	switch v {
   121  	case 5: // Add file MIME type to posts.
   122  		boards := db.AllBoards()
   123  		for _, b := range boards {
   124  			allThreads := db.AllThreads(b, false)
   125  			for _, threadInfo := range allThreads {
   126  				posts := db.AllPostsInThread(threadInfo[0], false)
   127  				for _, post := range posts {
   128  					if post.File != "" && !post.IsEmbed() {
   129  						if strings.HasSuffix(post.File, ".tgkr") {
   130  							post.FileMIME = "application/x-tegaki"
   131  						} else {
   132  							mimeInfo, err := mimetype.DetectFile(filepath.Join(rootDir, b.Dir, "src", post.File))
   133  							if err == nil {
   134  								post.FileMIME = mimeInfo.String()
   135  							}
   136  						}
   137  						if post.FileMIME != "" {
   138  							_, err = db.conn.Exec(context.Background(), "UPDATE post SET filemime = $1 WHERE id = $2", post.FileMIME, post.ID)
   139  							if err != nil {
   140  								return err
   141  							}
   142  						}
   143  					}
   144  				}
   145  			}
   146  		}
   147  	}
   148  	return nil
   149  }
   150  
   151  func (db *DB) upgrade(rootDir string) error {
   152  	var versionString string
   153  	err := db.conn.QueryRow(context.Background(), "SELECT value FROM config WHERE name = 'version'").Scan(&versionString)
   154  	if err != nil {
   155  		return fmt.Errorf("failed to select database version: %s", err)
   156  	}
   157  	version, err := strconv.Atoi(versionString)
   158  	if err != nil {
   159  		return fmt.Errorf("failed to parse database version: %s", err)
   160  	}
   161  	maxVersion := len(dbSchema)
   162  	if version == maxVersion {
   163  		return nil
   164  	} else if version > maxVersion {
   165  		return fmt.Errorf("database version %d is newer than application version %d", version, maxVersion)
   166  	}
   167  	fmt.Printf("Upgrading database from version %d to %d...\n", version, maxVersion)
   168  	for v := version + 1; v <= maxVersion; v++ {
   169  		err = db._upgrade(rootDir, v)
   170  		if err != nil {
   171  			return fmt.Errorf("failed to upgrade database from version %d to version %d: %s", v-1, v, err)
   172  		}
   173  	}
   174  	fmt.Printf("Database upgraded.\n")
   175  	return nil
   176  }
   177  
   178  func Begin(pool *pgxpool.Pool, config *Config) *DB {
   179  	if pool == nil {
   180  		// Return mock database.
   181  		return &DB{
   182  			config: config,
   183  		}
   184  	}
   185  
   186  	conn, err := pool.Acquire(context.Background())
   187  	if err != nil {
   188  		log.Fatalf("failed to acquire connection from pool: %s", err)
   189  	}
   190  
   191  	_, err = conn.Exec(context.Background(), "BEGIN")
   192  	if err != nil {
   193  		conn.Release()
   194  		log.Fatalf("failed to begin transaction: %s", err)
   195  	}
   196  
   197  	return &DB{
   198  		conn:   conn,
   199  		config: config,
   200  	}
   201  }
   202  
   203  func (db *DB) SoftRollBack() {
   204  	if db.conn == nil {
   205  		return
   206  	}
   207  	_, err := db.conn.Exec(context.Background(), "ROLLBACK")
   208  	if err != nil {
   209  		log.Fatalf("failed to rollback transaction: %s", err)
   210  	}
   211  	_, err = db.conn.Exec(context.Background(), "BEGIN")
   212  	if err != nil {
   213  		log.Fatalf("failed to begin transaction: %s", err)
   214  	}
   215  }
   216  
   217  func (db *DB) RollBack() {
   218  	if db.conn == nil {
   219  		return
   220  	}
   221  	_, err := db.conn.Exec(context.Background(), "ROLLBACK")
   222  	if err != nil {
   223  		log.Fatalf("failed to rollback transaction: %s", err)
   224  	}
   225  	db.conn.Release()
   226  	db.committed = true
   227  }
   228  
   229  func (db *DB) Commit() {
   230  	if db.conn == nil || db.committed {
   231  		return
   232  	}
   233  	_, err := db.conn.Exec(context.Background(), "COMMIT")
   234  	if err != nil {
   235  		log.Fatalf("failed to commit transaction: %s", err)
   236  	}
   237  	db.conn.Release()
   238  	db.committed = true
   239  }
   240  
   241  func (db *DB) CommitWithErr() error {
   242  	if db.conn == nil || db.committed {
   243  		return nil
   244  	}
   245  	_, err := db.conn.Exec(context.Background(), "COMMIT")
   246  	if err != nil {
   247  		return fmt.Errorf("failed to commit transaction: %s", err)
   248  	}
   249  	db.conn.Release()
   250  	db.committed = true
   251  	return nil
   252  }
   253  
   254  func (db *DB) configKey(key string) string {
   255  	key = strings.ToLower(key)
   256  	if len(db.Plugin) != 0 {
   257  		return db.Plugin + "." + key
   258  	}
   259  	return key
   260  }
   261  
   262  func (db *DB) HaveConfig(key string) bool {
   263  	if db.conn == nil {
   264  		return false
   265  	}
   266  	key = db.configKey(key)
   267  	var count int
   268  	err := db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM config WHERE name = $1", key).Scan(&count)
   269  	if err == pgx.ErrNoRows {
   270  		return false
   271  	} else if err != nil {
   272  		log.Fatalf("failed to select config count %s: %s", key, err)
   273  	}
   274  	return count > 0
   275  }
   276  
   277  func (db *DB) GetString(key string) string {
   278  	if db.conn == nil {
   279  		return ""
   280  	}
   281  	key = db.configKey(key)
   282  	var value string
   283  	err := db.conn.QueryRow(context.Background(), "SELECT value FROM config WHERE name = $1", key).Scan(&value)
   284  	if err == pgx.ErrNoRows {
   285  		return ""
   286  	} else if err != nil {
   287  		log.Fatalf("failed to get string %s: %s", key, err)
   288  	}
   289  	return value
   290  }
   291  
   292  func (db *DB) SaveString(key string, value string) {
   293  	if db.conn == nil {
   294  		return
   295  	}
   296  	value = strings.ReplaceAll(value, "\r", "")
   297  	_, err := db.conn.Exec(context.Background(), "INSERT INTO config VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET value = $3", db.configKey(key), value, value)
   298  	if err != nil {
   299  		log.Fatalf("failed to save string: %s", err)
   300  	}
   301  }
   302  
   303  func (db *DB) GetMultiString(key string) []string {
   304  	return strings.Split(db.GetString(key), "|||")
   305  }
   306  
   307  func (db *DB) GetBool(key string) bool {
   308  	return db.GetString(key) == "1"
   309  }
   310  
   311  func (db *DB) SaveBool(key string, value bool) {
   312  	v := "0"
   313  	if value {
   314  		v = "1"
   315  	}
   316  	db.SaveString(key, v)
   317  }
   318  
   319  func (db *DB) SaveMultiString(key string, value []string) {
   320  	db.SaveString(key, strings.Join(value, "|||"))
   321  }
   322  
   323  func (db *DB) GetInt(key string) int {
   324  	return ParseInt(db.GetString(key))
   325  }
   326  
   327  func (db *DB) SaveInt(key string, value int) {
   328  	db.SaveString(key, strconv.Itoa(value))
   329  }
   330  
   331  func (db *DB) GetInt64(key string) int64 {
   332  	return ParseInt64(db.GetString(key))
   333  }
   334  
   335  func (db *DB) SaveInt64(key string, value int64) {
   336  	db.SaveString(key, fmt.Sprintf("%d", value))
   337  }
   338  
   339  func (db *DB) GetMultiInt(key string) []int {
   340  	s := db.GetString(key)
   341  	if s == "" {
   342  		return nil
   343  	}
   344  	var values []int
   345  	for _, v := range strings.Split(s, "|||") {
   346  		values = append(values, ParseInt(v))
   347  	}
   348  	return values
   349  }
   350  
   351  func (db *DB) SaveMultiInt(key string, values []int) {
   352  	var out string
   353  	for i, v := range values {
   354  		if i != 0 {
   355  			out += "|||"
   356  		}
   357  		out += strconv.Itoa(v)
   358  	}
   359  	db.SaveString(key, out)
   360  }
   361  
   362  func (db *DB) GetFloat(key string) float64 {
   363  	return ParseFloat(db.GetString(key))
   364  }
   365  
   366  func (db *DB) SaveFloat(key string, value float64) {
   367  	db.SaveString(key, fmt.Sprintf("%f", value))
   368  }
   369  
   370  func (db *DB) newSessionKey() string {
   371  	const keyLength = 48
   372  	buf := make([]byte, keyLength)
   373  	for {
   374  		_, err := rand.Read(buf)
   375  		if err != nil {
   376  			panic(err)
   377  		}
   378  		sessionKey := base64.URLEncoding.EncodeToString(buf)
   379  
   380  		var numAccounts int
   381  		err = db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM account WHERE session = $1", sessionKey).Scan(&numAccounts)
   382  		if err != nil {
   383  			log.Fatalf("failed to select number of accounts with session key: %s", err)
   384  		} else if numAccounts == 0 {
   385  			return sessionKey
   386  		}
   387  	}
   388  }
   389  
   390  func (db *DB) Exec(sql string, arguments ...any) (pgconn.CommandTag, error) {
   391  	return db.conn.Exec(context.Background(), sql, arguments...)
   392  }
   393  
   394  func (db *DB) QueryRow(sql string, arguments ...any) pgx.Row {
   395  	return db.conn.QueryRow(context.Background(), sql, arguments...)
   396  }
   397  
   398  // Validate database interface during compilation.
   399  var (
   400  	_ sriracha.DB = &DB{}
   401  )
   402  

View as plain text