1 package database
2
3 import (
4 "context"
5 "log"
6 "time"
7
8 . "codeberg.org/tslocum/sriracha/model"
9 "github.com/jackc/pgx/v5"
10 )
11
12 func (db *DB) AddBan(b *Ban) {
13 _, err := db.conn.Exec(context.Background(), "INSERT INTO ban VALUES (DEFAULT, $1, $2, $3, $4)",
14 b.IP,
15 time.Now().Unix(),
16 b.Expire,
17 b.Reason,
18 )
19 if err != nil {
20 log.Fatalf("failed to insert ban: %s", err)
21 }
22 err = db.conn.QueryRow(context.Background(), "SELECT id FROM ban WHERE ip = $1", b.IP).Scan(&b.ID)
23 if err != nil || b.ID == 0 {
24 log.Fatalf("failed to select id of inserted ban: %s", err)
25 }
26 }
27
28 func (db *DB) BanByID(id int) *Ban {
29 b := &Ban{}
30 err := scanBan(b, db.conn.QueryRow(context.Background(), "SELECT * FROM ban WHERE id = $1", id))
31 if err == pgx.ErrNoRows {
32 return nil
33 } else if err != nil {
34 log.Fatalf("failed to select ban: %s", err)
35 }
36 return b
37 }
38
39 func (db *DB) BanByIP(ip string) *Ban {
40 b := &Ban{}
41 err := scanBan(b, db.conn.QueryRow(context.Background(), "SELECT * FROM ban WHERE ip = $1", ip))
42 if err == pgx.ErrNoRows {
43 return nil
44 } else if err != nil {
45 log.Fatalf("failed to select ban: %s", err)
46 }
47 return b
48 }
49
50 func (db *DB) AllBans(rangeOnly bool) []*Ban {
51 if db.conn == nil {
52 return nil
53 }
54 var extra string
55 if rangeOnly {
56 extra = " WHERE ip LIKE 'r %'"
57 }
58 rows, err := db.conn.Query(context.Background(), "SELECT * FROM ban"+extra+" ORDER BY timestamp DESC")
59 if err != nil {
60 log.Fatalf("failed to select all bans: %s", err)
61 }
62 var bans []*Ban
63 for rows.Next() {
64 b := &Ban{}
65 err := scanBan(b, rows)
66 if err != nil {
67 return nil
68 }
69 bans = append(bans, b)
70 }
71 return bans
72 }
73
74 func (db *DB) UpdateBan(b *Ban) {
75 if b.ID <= 0 {
76 log.Fatalf("invalid ban ID %d", b.ID)
77 }
78 _, err := db.conn.Exec(context.Background(), "UPDATE ban SET expire = $1, reason = $2 WHERE id = $3",
79 b.Expire,
80 b.Reason,
81 b.ID,
82 )
83 if err != nil {
84 log.Fatalf("failed to update ban: %s", err)
85 }
86 }
87
88 func (db *DB) DeleteExpiredBans() int {
89 var deleted int
90 err := db.conn.QueryRow(context.Background(), "WITH deleted AS (DELETE FROM ban WHERE expire != 0 AND expire <= $1 RETURNING *) SELECT COUNT(*) FROM deleted", time.Now().Unix()).Scan(&deleted)
91 if err != nil {
92 log.Fatal(err)
93 }
94 return deleted
95 }
96
97 func (db *DB) DeleteBan(id int) {
98 if id == 0 {
99 return
100 }
101 _, err := db.conn.Exec(context.Background(), "DELETE FROM ban WHERE id = $1", id)
102 if err != nil {
103 log.Fatalf("failed to delete ban: %s", err)
104 }
105 }
106
107 func scanBan(b *Ban, row pgx.Row) error {
108 return row.Scan(
109 &b.ID,
110 &b.IP,
111 &b.Timestamp,
112 &b.Expire,
113 &b.Reason,
114 )
115 }
116
117 func (db *DB) AddFileBan(fileHash string) {
118 _, err := db.conn.Exec(context.Background(), "INSERT INTO banfile VALUES ($1) ON CONFLICT DO NOTHING", fileHash)
119 if err != nil {
120 log.Fatalf("failed to ban file: %s", err)
121 }
122 }
123
124 func (db *DB) FileBanned(fileHash string) bool {
125 var banned bool
126 err := db.conn.QueryRow(context.Background(), "SELECT true FROM banfile WHERE hash = $1", fileHash).Scan(&banned)
127 if err == pgx.ErrNoRows {
128 return false
129 } else if err != nil {
130 log.Fatalf("failed to check if file is banned: %s", err)
131 }
132 return banned
133 }
134
135 func (db *DB) LiftFileBan(fileHash string) {
136 _, err := db.conn.Exec(context.Background(), "DELETE FROM banfile WHERE hash = $1", fileHash)
137 if err != nil {
138 log.Fatalf("failed to lift file ban: %s", err)
139 }
140 }
141
View as plain text