pg.go (5156B)
1 package postgres 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 8 pgx "github.com/jackc/pgx/v5" 9 "github.com/jackc/pgx/v5/pgxpool" 10 11 "git.defalsify.org/vise.git/db" 12 ) 13 14 var ( 15 defaultTxOptions pgx.TxOptions 16 ) 17 18 type PgInterface interface { 19 BeginTx(context.Context, pgx.TxOptions) (pgx.Tx, error) 20 Close() 21 } 22 23 // pgDb is a Postgres backend implementation of the Db interface. 24 type pgDb struct { 25 *db.DbBase 26 conn PgInterface 27 schema string 28 prefix uint8 29 prepd bool 30 it pgx.Rows 31 itBase []byte 32 tx pgx.Tx 33 multi bool 34 } 35 36 // NewpgDb creates a new Postgres backed Db implementation. 37 func NewPgDb() *pgDb { 38 db := &pgDb{ 39 DbBase: db.NewDbBase(), 40 schema: "public", 41 } 42 return db 43 } 44 45 // Base implements Db 46 func (pdb *pgDb) Base() *db.DbBase { 47 return pdb.DbBase 48 } 49 50 // WithSchema sets the Postgres schema to use for the storage table. 51 func (pdb *pgDb) WithSchema(schema string) *pgDb { 52 pdb.schema = schema 53 return pdb 54 } 55 56 func (pdb *pgDb) WithConnection(pi PgInterface) *pgDb { 57 pdb.conn = pi 58 return pdb 59 } 60 61 // Connect implements Db. 62 func (pdb *pgDb) Connect(ctx context.Context, connStr string) error { 63 if pdb.conn != nil { 64 logg.WarnCtxf(ctx, "Pg already connected") 65 return nil 66 } 67 conn, err := pgxpool.New(ctx, connStr) 68 if err != nil { 69 return err 70 } 71 72 if err := conn.Ping(ctx); err != nil { 73 return fmt.Errorf("connection to postgres could not be established: %w", err) 74 } 75 76 pdb.conn = conn 77 pdb.DbBase.Connect(ctx, connStr) 78 return pdb.ensureTable(ctx) 79 } 80 81 func (pdb *pgDb) Start(ctx context.Context) error { 82 if pdb.tx != nil { 83 return db.ErrTxExist 84 } 85 err := pdb.start(ctx) 86 if err != nil { 87 return err 88 } 89 pdb.multi = true 90 return nil 91 } 92 93 func (pdb *pgDb) start(ctx context.Context) error { 94 if pdb.tx != nil { 95 return nil 96 } 97 tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions) 98 logg.TraceCtxf(ctx, "begin single tx", "err", err) 99 if err != nil { 100 return err 101 } 102 pdb.tx = tx 103 return nil 104 } 105 106 func (pdb *pgDb) Stop(ctx context.Context) error { 107 if !pdb.multi { 108 return db.ErrSingleTx 109 } 110 return pdb.stop(ctx) 111 } 112 113 func (pdb *pgDb) stopSingle(ctx context.Context) error { 114 if pdb.multi { 115 return nil 116 } 117 err := pdb.tx.Commit(ctx) 118 logg.TraceCtxf(ctx, "stop single tx", "err", err) 119 pdb.tx = nil 120 return err 121 } 122 123 func (pdb *pgDb) stop(ctx context.Context) error { 124 if pdb.tx == nil { 125 return db.ErrNoTx 126 } 127 err := pdb.tx.Commit(ctx) 128 logg.TraceCtxf(ctx, "stop multi tx", "err", err) 129 pdb.tx = nil 130 return err 131 } 132 133 func (pdb *pgDb) Abort(ctx context.Context) { 134 logg.InfoCtxf(ctx, "aborting tx", "tx", pdb.tx) 135 pdb.tx.Rollback(ctx) 136 pdb.tx = nil 137 } 138 139 // Put implements Db. 140 func (pdb *pgDb) Put(ctx context.Context, key []byte, val []byte) error { 141 if !pdb.CheckPut() { 142 return errors.New("unsafe put and safety set") 143 } 144 145 lk, err := pdb.ToKey(ctx, key) 146 if err != nil { 147 return err 148 } 149 150 err = pdb.start(ctx) 151 if err != nil { 152 return err 153 } 154 logg.TraceCtxf(ctx, "put", "key", key, "val", val) 155 query := fmt.Sprintf("INSERT INTO %s.kv_vise (key, value, updated) VALUES ($1, $2, 'now') ON CONFLICT(key) DO UPDATE SET value = $2, updated = 'now';", pdb.schema) 156 actualKey := lk.Default 157 if lk.Translation != nil { 158 actualKey = lk.Translation 159 } 160 161 _, err = pdb.tx.Exec(ctx, query, actualKey, val) 162 if err != nil { 163 return err 164 } 165 166 return pdb.stopSingle(ctx) 167 } 168 169 // Get implements Db. 170 func (pdb *pgDb) Get(ctx context.Context, key []byte) ([]byte, error) { 171 var rr []byte 172 lk, err := pdb.ToKey(ctx, key) 173 if err != nil { 174 return nil, err 175 } 176 177 err = pdb.start(ctx) 178 if err != nil { 179 return nil, err 180 } 181 logg.TraceCtxf(ctx, "get", "key", key) 182 183 if lk.Translation != nil { 184 query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema) 185 rs, err := pdb.tx.Query(ctx, query, lk.Translation) 186 if err != nil { 187 pdb.Abort(ctx) 188 return nil, err 189 } 190 191 if rs.Next() { 192 err = rs.Scan(&rr) 193 if err != nil { 194 pdb.Abort(ctx) 195 return nil, err 196 } 197 198 rs.Close() 199 err = pdb.stopSingle(ctx) 200 return rr, err 201 } 202 } 203 204 query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema) 205 rs, err := pdb.tx.Query(ctx, query, lk.Default) 206 if err != nil { 207 pdb.Abort(ctx) 208 return nil, err 209 } 210 211 if !rs.Next() { 212 rs.Close() 213 return nil, db.NewErrNotFound(key) 214 } 215 216 err = rs.Scan(&rr) 217 if err != nil { 218 rs.Close() 219 pdb.Abort(ctx) 220 return nil, err 221 } 222 rs.Close() 223 err = pdb.stopSingle(ctx) 224 return rr, err 225 } 226 227 // Close implements Db. 228 func (pdb *pgDb) Close(ctx context.Context) error { 229 err := pdb.Stop(ctx) 230 if err == db.ErrNoTx { 231 err = nil 232 } 233 pdb.conn.Close() 234 return err 235 } 236 237 // set up table 238 func (pdb *pgDb) ensureTable(ctx context.Context) error { 239 if pdb.prepd { 240 logg.WarnCtxf(ctx, "ensureTable called more than once") 241 return nil 242 } 243 tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions) 244 if err != nil { 245 tx.Rollback(ctx) 246 return err 247 } 248 query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.kv_vise ( 249 id SERIAL NOT NULL, 250 key BYTEA NOT NULL UNIQUE, 251 value BYTEA NOT NULL, 252 updated TIMESTAMP NOT NULL 253 ); 254 `, pdb.schema) 255 _, err = tx.Exec(ctx, query) 256 if err != nil { 257 tx.Rollback(ctx) 258 return err 259 } 260 261 err = tx.Commit(ctx) 262 if err != nil { 263 tx.Rollback(ctx) 264 return err 265 } 266 pdb.prepd = true 267 return nil 268 }