pg.go (5092B)
1 package postgres 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 8 pgx "github.com/jackc/pgx/v5" 9 "github.com/jackc/pgx/v5/pgxpool" 10 11 "git.defalsify.org/vise.git/db" 12 ) 13 14 var ( 15 defaultTxOptions pgx.TxOptions 16 ) 17 18 type PgInterface interface { 19 BeginTx(context.Context, pgx.TxOptions) (pgx.Tx, error) 20 Close() 21 } 22 23 // pgDb is a Postgres backend implementation of the Db interface. 24 type pgDb struct { 25 *db.DbBase 26 conn PgInterface 27 schema string 28 prefix uint8 29 prepd bool 30 it pgx.Rows 31 itBase []byte 32 tx pgx.Tx 33 multi bool 34 } 35 36 // NewpgDb creates a new Postgres backed Db implementation. 37 func NewPgDb() *pgDb { 38 db := &pgDb{ 39 DbBase: db.NewDbBase(), 40 schema: "public", 41 } 42 return db 43 } 44 45 // WithSchema sets the Postgres schema to use for the storage table. 46 func (pdb *pgDb) WithSchema(schema string) *pgDb { 47 pdb.schema = schema 48 return pdb 49 } 50 51 func (pdb *pgDb) WithConnection(pi PgInterface) *pgDb { 52 pdb.conn = pi 53 return pdb 54 } 55 56 // Connect implements Db. 57 func (pdb *pgDb) Connect(ctx context.Context, connStr string) error { 58 if pdb.conn != nil { 59 logg.WarnCtxf(ctx, "Pg already connected") 60 return nil 61 } 62 conn, err := pgxpool.New(ctx, connStr) 63 if err != nil { 64 return err 65 } 66 67 if err := conn.Ping(ctx); err != nil { 68 return fmt.Errorf("connection to postgres could not be established: %w", err) 69 } 70 71 pdb.conn = conn 72 pdb.DbBase.Connect(ctx, connStr) 73 return pdb.ensureTable(ctx) 74 } 75 76 func (pdb *pgDb) Start(ctx context.Context) error { 77 if pdb.tx != nil { 78 return db.ErrTxExist 79 } 80 err := pdb.start(ctx) 81 if err != nil { 82 return err 83 } 84 pdb.multi = true 85 return nil 86 } 87 88 func (pdb *pgDb) start(ctx context.Context) error { 89 if pdb.tx != nil { 90 return nil 91 } 92 tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions) 93 logg.TraceCtxf(ctx, "begin single tx", "err", err) 94 if err != nil { 95 return err 96 } 97 pdb.tx = tx 98 return nil 99 } 100 101 func (pdb *pgDb) Stop(ctx context.Context) error { 102 if !pdb.multi { 103 return db.ErrSingleTx 104 } 105 return pdb.stop(ctx) 106 } 107 108 func (pdb *pgDb) stopSingle(ctx context.Context) error { 109 if pdb.multi { 110 return nil 111 } 112 err := pdb.tx.Commit(ctx) 113 logg.TraceCtxf(ctx, "stop single tx", "err", err) 114 pdb.tx = nil 115 return err 116 } 117 118 func (pdb *pgDb) stop(ctx context.Context) error { 119 if pdb.tx == nil { 120 return db.ErrNoTx 121 } 122 err := pdb.tx.Commit(ctx) 123 logg.TraceCtxf(ctx, "stop multi tx", "err", err) 124 pdb.tx = nil 125 return err 126 } 127 128 func (pdb *pgDb) Abort(ctx context.Context) { 129 logg.InfoCtxf(ctx, "aborting tx", "tx", pdb.tx) 130 pdb.tx.Rollback(ctx) 131 pdb.tx = nil 132 } 133 134 // Put implements Db. 135 func (pdb *pgDb) Put(ctx context.Context, key []byte, val []byte) error { 136 if !pdb.CheckPut() { 137 return errors.New("unsafe put and safety set") 138 } 139 140 lk, err := pdb.ToKey(ctx, key) 141 if err != nil { 142 return err 143 } 144 145 err = pdb.start(ctx) 146 if err != nil { 147 return err 148 } 149 logg.TraceCtxf(ctx, "put", "key", key, "val", val) 150 query := fmt.Sprintf("INSERT INTO %s.kv_vise (key, value, updated) VALUES ($1, $2, 'now') ON CONFLICT(key) DO UPDATE SET value = $2, updated = 'now';", pdb.schema) 151 actualKey := lk.Default 152 if lk.Translation != nil { 153 actualKey = lk.Translation 154 } 155 156 _, err = pdb.tx.Exec(ctx, query, actualKey, val) 157 if err != nil { 158 return err 159 } 160 161 return pdb.stopSingle(ctx) 162 } 163 164 // Get implements Db. 165 func (pdb *pgDb) Get(ctx context.Context, key []byte) ([]byte, error) { 166 var rr []byte 167 lk, err := pdb.ToKey(ctx, key) 168 if err != nil { 169 return nil, err 170 } 171 172 err = pdb.start(ctx) 173 if err != nil { 174 return nil, err 175 } 176 logg.TraceCtxf(ctx, "get", "key", key) 177 178 if lk.Translation != nil { 179 query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema) 180 rs, err := pdb.tx.Query(ctx, query, lk.Translation) 181 if err != nil { 182 pdb.Abort(ctx) 183 return nil, err 184 } 185 186 if rs.Next() { 187 err = rs.Scan(&rr) 188 if err != nil { 189 pdb.Abort(ctx) 190 return nil, err 191 } 192 193 rs.Close() 194 err = pdb.stopSingle(ctx) 195 return rr, err 196 } 197 } 198 199 query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema) 200 rs, err := pdb.tx.Query(ctx, query, lk.Default) 201 if err != nil { 202 pdb.Abort(ctx) 203 return nil, err 204 } 205 206 if !rs.Next() { 207 rs.Close() 208 pdb.Abort(ctx) 209 return nil, db.NewErrNotFound(key) 210 } 211 212 err = rs.Scan(&rr) 213 if err != nil { 214 rs.Close() 215 pdb.Abort(ctx) 216 return nil, err 217 } 218 rs.Close() 219 err = pdb.stopSingle(ctx) 220 return rr, err 221 } 222 223 // Close implements Db. 224 func (pdb *pgDb) Close(ctx context.Context) error { 225 err := pdb.Stop(ctx) 226 if err == db.ErrNoTx { 227 err = nil 228 } 229 pdb.conn.Close() 230 return err 231 } 232 233 // set up table 234 func (pdb *pgDb) ensureTable(ctx context.Context) error { 235 if pdb.prepd { 236 logg.WarnCtxf(ctx, "ensureTable called more than once") 237 return nil 238 } 239 tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions) 240 if err != nil { 241 tx.Rollback(ctx) 242 return err 243 } 244 query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.kv_vise ( 245 id SERIAL NOT NULL, 246 key BYTEA NOT NULL UNIQUE, 247 value BYTEA NOT NULL, 248 updated TIMESTAMP NOT NULL 249 ); 250 `, pdb.schema) 251 _, err = tx.Exec(ctx, query) 252 if err != nil { 253 tx.Rollback(ctx) 254 return err 255 } 256 257 err = tx.Commit(ctx) 258 if err != nil { 259 tx.Rollback(ctx) 260 return err 261 } 262 pdb.prepd = true 263 return nil 264 }