1
2 package database
3
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"log"
	"net/url"
	"path/filepath"
	"strconv"
	"strings"

	"codeberg.org/tslocum/sriracha"
	. "codeberg.org/tslocum/sriracha/util"
	"github.com/alexedwards/argon2id"
	"github.com/gabriel-vasile/mimetype"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
)
22
23 var argon2idParameters = &argon2id.Params{
24 Memory: 128 * 1024,
25 Iterations: 2,
26 Parallelism: 2,
27 SaltLength: 16,
28 KeyLength: 64,
29 }
30
31
32 type DB struct {
33 conn *pgxpool.Conn
34 Plugin string
35 config *Config
36 committed bool
37 }
38
39 func Connect(c *Config) (*pgxpool.Pool, error) {
40 url := c.DBURL
41 if strings.TrimSpace(url) == "" {
42 url = fmt.Sprintf("postgres://%s:%s@%s/%s", c.Username, c.Password, c.Address, c.DBName)
43 }
44
45 config, err := pgxpool.ParseConfig(url)
46 if err != nil {
47 return nil, fmt.Errorf("failed to parse database configuration: %s", err)
48 }
49 config.MinConns = 1
50 config.MinIdleConns = 1
51 config.MaxConns = 1
52
53 pool, err := pgxpool.NewWithConfig(context.Background(), config)
54 if err != nil {
55 return nil, fmt.Errorf("failed to connect to database: %s", err)
56 }
57
58 conn, err := pool.Acquire(context.Background())
59 if err != nil {
60 return nil, fmt.Errorf("failed to acquire conn: %s", err)
61 }
62 defer conn.Release()
63
64 _, err = conn.Exec(context.Background(), "BEGIN")
65 if err != nil {
66 return nil, fmt.Errorf("failed to begin transaction: %s", err)
67 }
68
69 db := &DB{
70 conn: conn,
71 config: c,
72 }
73 err = db.initialize()
74 if err != nil {
75 return nil, fmt.Errorf("failed to initialize database: %s", err)
76 }
77
78 err = db.upgrade(c.Root)
79 if err != nil {
80 return nil, fmt.Errorf("failed to upgrade database: %s", err)
81 }
82
83 db.createSuperAdminAccount(c.SaltPass)
84
85 _, err = conn.Exec(context.Background(), "COMMIT")
86 if err != nil {
87 return nil, fmt.Errorf("failed to commit transaction: %s", err)
88 }
89 return pool, nil
90 }
91
92 func (db *DB) initialize() error {
93 _, err := db.conn.Exec(context.Background(), "SELECT 1=1")
94 if err != nil {
95 return fmt.Errorf("failed to test database connection: %s", err)
96 }
97
98 var tablecount int
99 err = db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'account'").Scan(&tablecount)
100 if err != nil {
101 return fmt.Errorf("failed to select whether account table exists: %s", err)
102 } else if tablecount > 0 {
103 return nil
104 }
105
106 fmt.Printf("Initializing database version 1...\n")
107 _, err = db.conn.Exec(context.Background(), dbSchema[0])
108 if err != nil {
109 return fmt.Errorf("failed to create database: %s", err)
110 }
111 fmt.Printf("Database initialized.\n")
112 return nil
113 }
114
115 func (db *DB) _upgrade(rootDir string, v int) error {
116 _, err := db.conn.Exec(context.Background(), dbSchema[v-1])
117 if err != nil {
118 return err
119 }
120 switch v {
121 case 5:
122 boards := db.AllBoards()
123 for _, b := range boards {
124 allThreads := db.AllThreads(b, false)
125 for _, threadInfo := range allThreads {
126 posts := db.AllPostsInThread(threadInfo[0], false)
127 for _, post := range posts {
128 if post.File != "" && !post.IsEmbed() {
129 if strings.HasSuffix(post.File, ".tgkr") {
130 post.FileMIME = "application/x-tegaki"
131 } else {
132 mimeInfo, err := mimetype.DetectFile(filepath.Join(rootDir, b.Dir, "src", post.File))
133 if err == nil {
134 post.FileMIME = mimeInfo.String()
135 }
136 }
137 if post.FileMIME != "" {
138 _, err = db.conn.Exec(context.Background(), "UPDATE post SET filemime = $1 WHERE id = $2", post.FileMIME, post.ID)
139 if err != nil {
140 return err
141 }
142 }
143 }
144 }
145 }
146 }
147 }
148 return nil
149 }
150
151 func (db *DB) upgrade(rootDir string) error {
152 var versionString string
153 err := db.conn.QueryRow(context.Background(), "SELECT value FROM config WHERE name = 'version'").Scan(&versionString)
154 if err != nil {
155 return fmt.Errorf("failed to select database version: %s", err)
156 }
157 version, err := strconv.Atoi(versionString)
158 if err != nil {
159 return fmt.Errorf("failed to parse database version: %s", err)
160 }
161 maxVersion := len(dbSchema)
162 if version == maxVersion {
163 return nil
164 } else if version > maxVersion {
165 return fmt.Errorf("database version %d is newer than application version %d", version, maxVersion)
166 }
167 fmt.Printf("Upgrading database from version %d to %d...\n", version, maxVersion)
168 for v := version + 1; v <= maxVersion; v++ {
169 err = db._upgrade(rootDir, v)
170 if err != nil {
171 return fmt.Errorf("failed to upgrade database from version %d to version %d: %s", v-1, v, err)
172 }
173 }
174 fmt.Printf("Database upgraded.\n")
175 return nil
176 }
177
178 func Begin(pool *pgxpool.Pool, config *Config) *DB {
179 if pool == nil {
180
181 return &DB{
182 config: config,
183 }
184 }
185
186 conn, err := pool.Acquire(context.Background())
187 if err != nil {
188 log.Fatalf("failed to acquire connection from pool: %s", err)
189 }
190
191 _, err = conn.Exec(context.Background(), "BEGIN")
192 if err != nil {
193 conn.Release()
194 log.Fatalf("failed to begin transaction: %s", err)
195 }
196
197 return &DB{
198 conn: conn,
199 config: config,
200 }
201 }
202
203 func (db *DB) SoftRollBack() {
204 if db.conn == nil {
205 return
206 }
207 _, err := db.conn.Exec(context.Background(), "ROLLBACK")
208 if err != nil {
209 log.Fatalf("failed to rollback transaction: %s", err)
210 }
211 _, err = db.conn.Exec(context.Background(), "BEGIN")
212 if err != nil {
213 log.Fatalf("failed to begin transaction: %s", err)
214 }
215 }
216
217 func (db *DB) RollBack() {
218 if db.conn == nil {
219 return
220 }
221 _, err := db.conn.Exec(context.Background(), "ROLLBACK")
222 if err != nil {
223 log.Fatalf("failed to rollback transaction: %s", err)
224 }
225 db.conn.Release()
226 db.committed = true
227 }
228
229 func (db *DB) Commit() {
230 if db.conn == nil || db.committed {
231 return
232 }
233 _, err := db.conn.Exec(context.Background(), "COMMIT")
234 if err != nil {
235 log.Fatalf("failed to commit transaction: %s", err)
236 }
237 db.conn.Release()
238 db.committed = true
239 }
240
241 func (db *DB) CommitWithErr() error {
242 if db.conn == nil || db.committed {
243 return nil
244 }
245 _, err := db.conn.Exec(context.Background(), "COMMIT")
246 if err != nil {
247 return fmt.Errorf("failed to commit transaction: %s", err)
248 }
249 db.conn.Release()
250 db.committed = true
251 return nil
252 }
253
254 func (db *DB) configKey(key string) string {
255 key = strings.ToLower(key)
256 if len(db.Plugin) != 0 {
257 return db.Plugin + "." + key
258 }
259 return key
260 }
261
262 func (db *DB) HaveConfig(key string) bool {
263 if db.conn == nil {
264 return false
265 }
266 key = db.configKey(key)
267 var count int
268 err := db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM config WHERE name = $1", key).Scan(&count)
269 if err == pgx.ErrNoRows {
270 return false
271 } else if err != nil {
272 log.Fatalf("failed to select config count %s: %s", key, err)
273 }
274 return count > 0
275 }
276
277 func (db *DB) GetString(key string) string {
278 if db.conn == nil {
279 return ""
280 }
281 key = db.configKey(key)
282 var value string
283 err := db.conn.QueryRow(context.Background(), "SELECT value FROM config WHERE name = $1", key).Scan(&value)
284 if err == pgx.ErrNoRows {
285 return ""
286 } else if err != nil {
287 log.Fatalf("failed to get string %s: %s", key, err)
288 }
289 return value
290 }
291
292 func (db *DB) SaveString(key string, value string) {
293 if db.conn == nil {
294 return
295 }
296 value = strings.ReplaceAll(value, "\r", "")
297 _, err := db.conn.Exec(context.Background(), "INSERT INTO config VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET value = $3", db.configKey(key), value, value)
298 if err != nil {
299 log.Fatalf("failed to save string: %s", err)
300 }
301 }
302
303 func (db *DB) GetMultiString(key string) []string {
304 return strings.Split(db.GetString(key), "|||")
305 }
306
307 func (db *DB) GetBool(key string) bool {
308 return db.GetString(key) == "1"
309 }
310
311 func (db *DB) SaveBool(key string, value bool) {
312 v := "0"
313 if value {
314 v = "1"
315 }
316 db.SaveString(key, v)
317 }
318
319 func (db *DB) SaveMultiString(key string, value []string) {
320 db.SaveString(key, strings.Join(value, "|||"))
321 }
322
323 func (db *DB) GetInt(key string) int {
324 return ParseInt(db.GetString(key))
325 }
326
327 func (db *DB) SaveInt(key string, value int) {
328 db.SaveString(key, strconv.Itoa(value))
329 }
330
331 func (db *DB) GetInt64(key string) int64 {
332 return ParseInt64(db.GetString(key))
333 }
334
335 func (db *DB) SaveInt64(key string, value int64) {
336 db.SaveString(key, fmt.Sprintf("%d", value))
337 }
338
339 func (db *DB) GetMultiInt(key string) []int {
340 s := db.GetString(key)
341 if s == "" {
342 return nil
343 }
344 var values []int
345 for _, v := range strings.Split(s, "|||") {
346 values = append(values, ParseInt(v))
347 }
348 return values
349 }
350
351 func (db *DB) SaveMultiInt(key string, values []int) {
352 var out string
353 for i, v := range values {
354 if i != 0 {
355 out += "|||"
356 }
357 out += strconv.Itoa(v)
358 }
359 db.SaveString(key, out)
360 }
361
362 func (db *DB) GetFloat(key string) float64 {
363 return ParseFloat(db.GetString(key))
364 }
365
366 func (db *DB) SaveFloat(key string, value float64) {
367 db.SaveString(key, fmt.Sprintf("%f", value))
368 }
369
370 func (db *DB) newSessionKey() string {
371 const keyLength = 48
372 buf := make([]byte, keyLength)
373 for {
374 _, err := rand.Read(buf)
375 if err != nil {
376 panic(err)
377 }
378 sessionKey := base64.URLEncoding.EncodeToString(buf)
379
380 var numAccounts int
381 err = db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM account WHERE session = $1", sessionKey).Scan(&numAccounts)
382 if err != nil {
383 log.Fatalf("failed to select number of accounts with session key: %s", err)
384 } else if numAccounts == 0 {
385 return sessionKey
386 }
387 }
388 }
389
390 func (db *DB) Exec(sql string, arguments ...any) (pgconn.CommandTag, error) {
391 return db.conn.Exec(context.Background(), sql, arguments...)
392 }
393
394 func (db *DB) QueryRow(sql string, arguments ...any) pgx.Row {
395 return db.conn.QueryRow(context.Background(), sql, arguments...)
396 }
397
398
399 var (
400 _ sriracha.DB = &DB{}
401 )
402
View as plain text