pg.go (5173B)
1 package postgres 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 8 pgx "github.com/jackc/pgx/v5" 9 "github.com/jackc/pgx/v5/pgxpool" 10 11 "git.defalsify.org/vise.git/db" 12 ) 13 14 var ( 15 defaultTxOptions pgx.TxOptions 16 ) 17 18 type PgInterface interface { 19 BeginTx(context.Context, pgx.TxOptions) (pgx.Tx, error) 20 Close() 21 } 22 23 // pgDb is a Postgres backend implementation of the Db interface. 24 type pgDb struct { 25 *db.DbBase 26 conn PgInterface 27 schema string 28 prefix uint8 29 prepd bool 30 it pgx.Rows 31 itBase []byte 32 tx pgx.Tx 33 multi bool 34 } 35 36 // NewpgDb creates a new Postgres backed Db implementation. 37 func NewPgDb() *pgDb { 38 db := &pgDb{ 39 DbBase: db.NewDbBase(), 40 schema: "public", 41 } 42 return db 43 } 44 45 // Base implements Db 46 func (pdb *pgDb) Base() *db.DbBase { 47 return pdb.DbBase 48 } 49 50 // WithSchema sets the Postgres schema to use for the storage table. 51 func (pdb *pgDb) WithSchema(schema string) *pgDb { 52 pdb.schema = schema 53 return pdb 54 } 55 56 func (pdb *pgDb) WithConnection(pi PgInterface) *pgDb { 57 pdb.conn = pi 58 return pdb 59 } 60 61 // Connect implements Db. 
//
// A pgxpool.Pool is created from connStr and pinged, then the kv_vise table
// is created if it does not exist. If a connection is already set (by an
// earlier Connect or by WithConnection), a warning is logged and nil is
// returned.
func (pdb *pgDb) Connect(ctx context.Context, connStr string) error {
	if pdb.conn != nil {
		logg.WarnCtxf(ctx, "Pg already connected")
		return nil
	}
	conn, err := pgxpool.New(ctx, connStr)
	if err != nil {
		return err
	}

	if err := conn.Ping(ctx); err != nil {
		// The pool is not stored anywhere yet; close it so its resources
		// are released instead of leaked.
		conn.Close()
		return fmt.Errorf("connection to postgres could not be established: %w", err)
	}

	pdb.conn = conn
	pdb.DbBase.Connect(ctx, connStr)
	return pdb.ensureTable(ctx)
}

// Start opens a transaction spanning multiple operations. Until Stop or
// Abort is called, operations reuse it instead of committing on their own.
// It returns db.ErrTxExist if a transaction is already open.
func (pdb *pgDb) Start(ctx context.Context) error {
	if pdb.tx != nil {
		return db.ErrTxExist
	}
	if err := pdb.start(ctx); err != nil {
		return err
	}
	pdb.multi = true
	return nil
}

// start opens a new transaction unless one is already open, in which case
// the open one is reused.
func (pdb *pgDb) start(ctx context.Context) error {
	if pdb.tx != nil {
		return nil
	}
	tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
	logg.TraceCtxf(ctx, "begin single tx", "err", err)
	if err != nil {
		return err
	}
	pdb.tx = tx
	return nil
}

// Stop commits the transaction opened by Start and ends the multi-operation
// session. It returns db.ErrSingleTx if no Start is in effect, and
// db.ErrNoTx if the transaction is no longer open.
func (pdb *pgDb) Stop(ctx context.Context) error {
	if !pdb.multi {
		return db.ErrSingleTx
	}
	return pdb.stop(ctx)
}

// stopSingle commits the transaction opened for a single operation. Inside
// a Start/Stop session it does nothing; Stop performs the commit instead.
func (pdb *pgDb) stopSingle(ctx context.Context) error {
	if pdb.multi {
		return nil
	}
	err := pdb.tx.Commit(ctx)
	logg.TraceCtxf(ctx, "stop single tx", "err", err)
	pdb.tx = nil
	return err
}

// stop commits the shared transaction and clears multi. Without the reset,
// later single operations would silently stop committing and a new Start
// would fail with db.ErrTxExist.
func (pdb *pgDb) stop(ctx context.Context) error {
	pdb.multi = false
	if pdb.tx == nil {
		return db.ErrNoTx
	}
	err := pdb.tx.Commit(ctx)
	logg.TraceCtxf(ctx, "stop multi tx", "err", err)
	pdb.tx = nil
	return err
}

// Abort rolls back the current transaction, if any, and ends a
// multi-operation session started with Start. It is safe to call when no
// transaction is open, for example after an operation has already aborted
// internally.
func (pdb *pgDb) Abort(ctx context.Context) {
	logg.InfoCtxf(ctx, "aborting tx", "tx", pdb.tx)
	// The batched work is discarded, so later operations commit on their own.
	pdb.multi = false
	if pdb.tx == nil {
		return
	}
	if err := pdb.tx.Rollback(ctx); err != nil {
		logg.WarnCtxf(ctx, "rollback failed", "err", err)
	}
	pdb.tx = nil
}

// Put implements Db.
140 func (pdb *pgDb) Put(ctx context.Context, key []byte, val []byte) error { 141 if !pdb.CheckPut() { 142 return errors.New("unsafe put and safety set") 143 } 144 145 lk, err := pdb.ToKey(ctx, key) 146 if err != nil { 147 return err 148 } 149 150 err = pdb.start(ctx) 151 if err != nil { 152 return err 153 } 154 logg.TraceCtxf(ctx, "put", "key", key, "val", val) 155 query := fmt.Sprintf("INSERT INTO %s.kv_vise (key, value, updated) VALUES ($1, $2, 'now') ON CONFLICT(key) DO UPDATE SET value = $2, updated = 'now';", pdb.schema) 156 actualKey := lk.Default 157 if lk.Translation != nil { 158 actualKey = lk.Translation 159 } 160 161 _, err = pdb.tx.Exec(ctx, query, actualKey, val) 162 if err != nil { 163 return err 164 } 165 166 return pdb.stopSingle(ctx) 167 } 168 169 // Get implements Db. 170 func (pdb *pgDb) Get(ctx context.Context, key []byte) ([]byte, error) { 171 var rr []byte 172 lk, err := pdb.ToKey(ctx, key) 173 if err != nil { 174 return nil, err 175 } 176 177 err = pdb.start(ctx) 178 if err != nil { 179 return nil, err 180 } 181 logg.TraceCtxf(ctx, "get", "key", key) 182 183 if lk.Translation != nil { 184 query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema) 185 rs, err := pdb.tx.Query(ctx, query, lk.Translation) 186 if err != nil { 187 pdb.Abort(ctx) 188 return nil, err 189 } 190 191 if rs.Next() { 192 err = rs.Scan(&rr) 193 if err != nil { 194 pdb.Abort(ctx) 195 return nil, err 196 } 197 198 rs.Close() 199 err = pdb.stopSingle(ctx) 200 return rr, err 201 } 202 } 203 204 query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema) 205 rs, err := pdb.tx.Query(ctx, query, lk.Default) 206 if err != nil { 207 pdb.Abort(ctx) 208 return nil, err 209 } 210 211 if !rs.Next() { 212 rs.Close() 213 pdb.Abort(ctx) 214 return nil, db.NewErrNotFound(key) 215 } 216 217 err = rs.Scan(&rr) 218 if err != nil { 219 rs.Close() 220 pdb.Abort(ctx) 221 return nil, err 222 } 223 rs.Close() 224 err = pdb.stopSingle(ctx) 225 return rr, 
err 226 } 227 228 // Close implements Db. 229 func (pdb *pgDb) Close(ctx context.Context) error { 230 err := pdb.Stop(ctx) 231 if err == db.ErrNoTx { 232 err = nil 233 } 234 pdb.conn.Close() 235 return err 236 } 237 238 // set up table 239 func (pdb *pgDb) ensureTable(ctx context.Context) error { 240 if pdb.prepd { 241 logg.WarnCtxf(ctx, "ensureTable called more than once") 242 return nil 243 } 244 tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions) 245 if err != nil { 246 tx.Rollback(ctx) 247 return err 248 } 249 query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.kv_vise ( 250 id SERIAL NOT NULL, 251 key BYTEA NOT NULL UNIQUE, 252 value BYTEA NOT NULL, 253 updated TIMESTAMP NOT NULL 254 ); 255 `, pdb.schema) 256 _, err = tx.Exec(ctx, query) 257 if err != nil { 258 tx.Rollback(ctx) 259 return err 260 } 261 262 err = tx.Commit(ctx) 263 if err != nil { 264 tx.Rollback(ctx) 265 return err 266 } 267 pdb.prepd = true 268 return nil 269 }