1
2 package database
3
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"log"
	"net/url"
	"path/filepath"
	"strconv"
	"strings"

	"codeberg.org/tslocum/sriracha"
	. "codeberg.org/tslocum/sriracha/util"
	"github.com/alexedwards/argon2id"
	"github.com/gabriel-vasile/mimetype"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
)
22
23 var argon2idParameters = &argon2id.Params{
24 Memory: 128 * 1024,
25 Iterations: 2,
26 Parallelism: 2,
27 SaltLength: 16,
28 KeyLength: 64,
29 }
30
31
32 type DB struct {
33 conn *pgxpool.Conn
34 Plugin string
35 config *Config
36 committed bool
37 }
38
39 func Connect(c *Config) (*pgxpool.Pool, error) {
40 url := c.DBURL
41 if strings.TrimSpace(url) == "" {
42 url = fmt.Sprintf("postgres://%s:%s@%s/%s", c.Username, c.Password, c.Address, c.DBName)
43 }
44
45 config, err := pgxpool.ParseConfig(url)
46 if err != nil {
47 return nil, fmt.Errorf("failed to parse database configuration: %s", err)
48 }
49 config.MinConns = 1
50 config.MinIdleConns = 1
51 config.MaxConns = 1
52
53 pool, err := pgxpool.NewWithConfig(context.Background(), config)
54 if err != nil {
55 return nil, fmt.Errorf("failed to connect to database: %s", err)
56 }
57
58 conn, err := pool.Acquire(context.Background())
59 if err != nil {
60 return nil, fmt.Errorf("failed to acquire conn: %s", err)
61 }
62 defer conn.Release()
63
64 _, err = conn.Exec(context.Background(), "BEGIN")
65 if err != nil {
66 return nil, fmt.Errorf("failed to begin transaction: %s", err)
67 }
68
69 db := &DB{
70 conn: conn,
71 config: c,
72 }
73 err = db.initialize()
74 if err != nil {
75 return nil, fmt.Errorf("failed to initialize database: %s", err)
76 }
77
78 err = db.upgrade(c.Root)
79 if err != nil {
80 return nil, fmt.Errorf("failed to upgrade database: %s", err)
81 }
82
83 db.createSuperAdminAccount(c.SaltPass)
84
85 _, err = conn.Exec(context.Background(), "COMMIT")
86 if err != nil {
87 return nil, fmt.Errorf("failed to commit transaction: %s", err)
88 }
89 return pool, nil
90 }
91
92 func (db *DB) initialize() error {
93 _, err := db.conn.Exec(context.Background(), "SELECT 1=1")
94 if err != nil {
95 return fmt.Errorf("failed to test database connection: %s", err)
96 }
97
98 var tablecount int
99 err = db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'account'").Scan(&tablecount)
100 if err != nil {
101 return fmt.Errorf("failed to select whether account table exists: %s", err)
102 } else if tablecount > 0 {
103 return nil
104 }
105
106 fmt.Printf("Initializing database version 1...\n")
107 _, err = db.conn.Exec(context.Background(), dbSchema[0])
108 if err != nil {
109 return fmt.Errorf("failed to create database: %s", err)
110 }
111 fmt.Printf("Database initialized.\n")
112 return nil
113 }
114
115 func (db *DB) _upgrade(rootDir string, v int) error {
116 _, err := db.conn.Exec(context.Background(), dbSchema[v-1])
117 if err != nil {
118 return err
119 }
120 switch v {
121 case 5:
122 boards := db.AllBoards()
123 for _, b := range boards {
124 allThreads := db.AllThreads(b, false)
125 for _, threadInfo := range allThreads {
126 posts := db.AllPostsInThread(threadInfo[0], false)
127 for _, post := range posts {
128 if post.File != "" && !post.IsEmbed() {
129 if strings.HasSuffix(post.File, ".tgkr") {
130 post.FileMIME = "application/x-tegaki"
131 } else {
132 mimeInfo, err := mimetype.DetectFile(filepath.Join(rootDir, b.Dir, "src", post.File))
133 if err == nil {
134 post.FileMIME = mimeInfo.String()
135 }
136 }
137 if post.FileMIME != "" {
138 _, err = db.conn.Exec(context.Background(), "UPDATE post SET filemime = $1 WHERE id = $2", post.FileMIME, post.ID)
139 if err != nil {
140 return err
141 }
142 }
143 }
144 }
145 }
146 }
147 }
148 return nil
149 }
150
151 func (db *DB) upgrade(rootDir string) error {
152 var versionString string
153 err := db.conn.QueryRow(context.Background(), "SELECT value FROM config WHERE name = 'version'").Scan(&versionString)
154 if err != nil {
155 return fmt.Errorf("failed to select database version: %s", err)
156 }
157 version, err := strconv.Atoi(versionString)
158 if err != nil {
159 return fmt.Errorf("failed to parse database version: %s", err)
160 }
161 maxVersion := len(dbSchema)
162 if version == maxVersion {
163 return nil
164 } else if version > maxVersion {
165 return fmt.Errorf("database version %d is newer than application version %d", version, maxVersion)
166 }
167 fmt.Printf("Upgrading database from version %d to %d...\n", version, maxVersion)
168 for v := version + 1; v <= maxVersion; v++ {
169 err = db._upgrade(rootDir, v)
170 if err != nil {
171 return fmt.Errorf("failed to upgrade database from version %d to version %d: %s", v-1, v, err)
172 }
173 }
174 fmt.Printf("Database upgraded.\n")
175 return nil
176 }
177
178 func Begin(pool *pgxpool.Pool, config *Config) *DB {
179 if pool == nil {
180
181 return &DB{
182 config: config,
183 }
184 }
185
186 conn, err := pool.Acquire(context.Background())
187 if err != nil {
188 log.Fatalf("failed to acquire connection from pool: %s", err)
189 }
190
191 _, err = conn.Exec(context.Background(), "BEGIN")
192 if err != nil {
193 conn.Release()
194 log.Fatalf("failed to begin transaction: %s", err)
195 }
196
197 return &DB{
198 conn: conn,
199 config: config,
200 }
201 }
202
203 func (db *DB) SoftRollBack() {
204 if db.conn == nil {
205 return
206 }
207 _, err := db.conn.Exec(context.Background(), "ROLLBACK")
208 if err != nil {
209 log.Fatalf("failed to rollback transaction: %s", err)
210 }
211 }
212
213 func (db *DB) RollBack() {
214 if db.conn == nil {
215 return
216 }
217 _, err := db.conn.Exec(context.Background(), "ROLLBACK")
218 if err != nil {
219 log.Fatalf("failed to rollback transaction: %s", err)
220 }
221 db.conn.Release()
222 }
223
224 func (db *DB) Commit() {
225 if db.conn == nil || db.committed {
226 return
227 }
228 _, err := db.conn.Exec(context.Background(), "COMMIT")
229 if err != nil {
230 log.Fatalf("failed to commit transaction: %s", err)
231 }
232 db.conn.Release()
233 db.committed = true
234 }
235
236 func (db *DB) CommitWithErr() error {
237 if db.conn == nil || db.committed {
238 return nil
239 }
240 _, err := db.conn.Exec(context.Background(), "COMMIT")
241 if err != nil {
242 return fmt.Errorf("failed to commit transaction: %s", err)
243 }
244 db.conn.Release()
245 db.committed = true
246 return nil
247 }
248
249 func (db *DB) configKey(key string) string {
250 key = strings.ToLower(key)
251 if len(db.Plugin) != 0 {
252 return db.Plugin + "." + key
253 }
254 return key
255 }
256
257 func (db *DB) HaveConfig(key string) bool {
258 if db.conn == nil {
259 return false
260 }
261 key = db.configKey(key)
262 var count int
263 err := db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM config WHERE name = $1", key).Scan(&count)
264 if err == pgx.ErrNoRows {
265 return false
266 } else if err != nil {
267 log.Fatalf("failed to select config count %s: %s", key, err)
268 }
269 return count > 0
270 }
271
272 func (db *DB) GetString(key string) string {
273 if db.conn == nil {
274 return ""
275 }
276 key = db.configKey(key)
277 var value string
278 err := db.conn.QueryRow(context.Background(), "SELECT value FROM config WHERE name = $1", key).Scan(&value)
279 if err == pgx.ErrNoRows {
280 return ""
281 } else if err != nil {
282 log.Fatalf("failed to get string %s: %s", key, err)
283 }
284 return value
285 }
286
287 func (db *DB) SaveString(key string, value string) {
288 if db.conn == nil {
289 return
290 }
291 value = strings.ReplaceAll(value, "\r", "")
292 _, err := db.conn.Exec(context.Background(), "INSERT INTO config VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET value = $3", db.configKey(key), value, value)
293 if err != nil {
294 log.Fatalf("failed to save string: %s", err)
295 }
296 }
297
298 func (db *DB) GetMultiString(key string) []string {
299 return strings.Split(db.GetString(key), "|||")
300 }
301
302 func (db *DB) GetBool(key string) bool {
303 return db.GetString(key) == "1"
304 }
305
306 func (db *DB) SaveBool(key string, value bool) {
307 v := "0"
308 if value {
309 v = "1"
310 }
311 db.SaveString(key, v)
312 }
313
314 func (db *DB) SaveMultiString(key string, value []string) {
315 db.SaveString(key, strings.Join(value, "|||"))
316 }
317
318 func (db *DB) GetInt(key string) int {
319 return ParseInt(db.GetString(key))
320 }
321
322 func (db *DB) SaveInt(key string, value int) {
323 db.SaveString(key, strconv.Itoa(value))
324 }
325
326 func (db *DB) GetInt64(key string) int64 {
327 return ParseInt64(db.GetString(key))
328 }
329
330 func (db *DB) SaveInt64(key string, value int64) {
331 db.SaveString(key, fmt.Sprintf("%d", value))
332 }
333
334 func (db *DB) GetMultiInt(key string) []int {
335 s := db.GetString(key)
336 if s == "" {
337 return nil
338 }
339 var values []int
340 for _, v := range strings.Split(s, "|||") {
341 values = append(values, ParseInt(v))
342 }
343 return values
344 }
345
346 func (db *DB) SaveMultiInt(key string, values []int) {
347 var out string
348 for i, v := range values {
349 if i != 0 {
350 out += "|||"
351 }
352 out += strconv.Itoa(v)
353 }
354 db.SaveString(key, out)
355 }
356
357 func (db *DB) GetFloat(key string) float64 {
358 return ParseFloat(db.GetString(key))
359 }
360
361 func (db *DB) SaveFloat(key string, value float64) {
362 db.SaveString(key, fmt.Sprintf("%f", value))
363 }
364
365 func (db *DB) newSessionKey() string {
366 const keyLength = 48
367 buf := make([]byte, keyLength)
368 for {
369 _, err := rand.Read(buf)
370 if err != nil {
371 panic(err)
372 }
373 sessionKey := base64.URLEncoding.EncodeToString(buf)
374
375 var numAccounts int
376 err = db.conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM account WHERE session = $1", sessionKey).Scan(&numAccounts)
377 if err != nil {
378 log.Fatalf("failed to select number of accounts with session key: %s", err)
379 } else if numAccounts == 0 {
380 return sessionKey
381 }
382 }
383 }
384
385 func (db *DB) Exec(sql string, arguments ...any) (pgconn.CommandTag, error) {
386 return db.conn.Exec(context.Background(), sql, arguments...)
387 }
388
389 func (db *DB) QueryRow(sql string, arguments ...any) pgx.Row {
390 return db.conn.QueryRow(context.Background(), sql, arguments...)
391 }
392
393
394 var (
395 _ sriracha.DB = &DB{}
396 )
397
View as plain text