go-vise

Constrained Size Output Virtual Machine
Info | Log | Files | Refs | README | LICENSE

pg.go (5173B)


      1 package postgres
      2 
      3 import (
      4 	"context"
      5 	"errors"
      6 	"fmt"
      7 
      8 	pgx "github.com/jackc/pgx/v5"
      9 	"github.com/jackc/pgx/v5/pgxpool"
     10 
     11 	"git.defalsify.org/vise.git/db"
     12 )
     13 
     14 var (
     15 	defaultTxOptions pgx.TxOptions
     16 )
     17 
     18 type PgInterface interface {
     19 	BeginTx(context.Context, pgx.TxOptions) (pgx.Tx, error)
     20 	Close()
     21 }
     22 
     23 // pgDb is a Postgres backend implementation of the Db interface.
     24 type pgDb struct {
     25 	*db.DbBase
     26 	conn   PgInterface
     27 	schema string
     28 	prefix uint8
     29 	prepd  bool
     30 	it     pgx.Rows
     31 	itBase []byte
     32 	tx     pgx.Tx
     33 	multi  bool
     34 }
     35 
     36 // NewpgDb creates a new Postgres backed Db implementation.
     37 func NewPgDb() *pgDb {
     38 	db := &pgDb{
     39 		DbBase: db.NewDbBase(),
     40 		schema: "public",
     41 	}
     42 	return db
     43 }
     44 
     45 // Base implements Db
     46 func (pdb *pgDb) Base() *db.DbBase {
     47 	return pdb.DbBase
     48 }
     49 
     50 // WithSchema sets the Postgres schema to use for the storage table.
     51 func (pdb *pgDb) WithSchema(schema string) *pgDb {
     52 	pdb.schema = schema
     53 	return pdb
     54 }
     55 
     56 func (pdb *pgDb) WithConnection(pi PgInterface) *pgDb {
     57 	pdb.conn = pi
     58 	return pdb
     59 }
     60 
     61 // Connect implements Db.
     62 func (pdb *pgDb) Connect(ctx context.Context, connStr string) error {
     63 	if pdb.conn != nil {
     64 		logg.WarnCtxf(ctx, "Pg already connected")
     65 		return nil
     66 	}
     67 	conn, err := pgxpool.New(ctx, connStr)
     68 	if err != nil {
     69 		return err
     70 	}
     71 
     72 	if err := conn.Ping(ctx); err != nil {
     73 		return fmt.Errorf("connection to postgres could not be established: %w", err)
     74 	}
     75 
     76 	pdb.conn = conn
     77 	pdb.DbBase.Connect(ctx, connStr)
     78 	return pdb.ensureTable(ctx)
     79 }
     80 
     81 func (pdb *pgDb) Start(ctx context.Context) error {
     82 	if pdb.tx != nil {
     83 		return db.ErrTxExist
     84 	}
     85 	err := pdb.start(ctx)
     86 	if err != nil {
     87 		return err
     88 	}
     89 	pdb.multi = true
     90 	return nil
     91 }
     92 
     93 func (pdb *pgDb) start(ctx context.Context) error {
     94 	if pdb.tx != nil {
     95 		return nil
     96 	}
     97 	tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
     98 	logg.TraceCtxf(ctx, "begin single tx", "err", err)
     99 	if err != nil {
    100 		return err
    101 	}
    102 	pdb.tx = tx
    103 	return nil
    104 }
    105 
    106 func (pdb *pgDb) Stop(ctx context.Context) error {
    107 	if !pdb.multi {
    108 		return db.ErrSingleTx
    109 	}
    110 	return pdb.stop(ctx)
    111 }
    112 
    113 func (pdb *pgDb) stopSingle(ctx context.Context) error {
    114 	if pdb.multi {
    115 		return nil
    116 	}
    117 	err := pdb.tx.Commit(ctx)
    118 	logg.TraceCtxf(ctx, "stop single tx", "err", err)
    119 	pdb.tx = nil
    120 	return err
    121 }
    122 
    123 func (pdb *pgDb) stop(ctx context.Context) error {
    124 	if pdb.tx == nil {
    125 		return db.ErrNoTx
    126 	}
    127 	err := pdb.tx.Commit(ctx)
    128 	logg.TraceCtxf(ctx, "stop multi tx", "err", err)
    129 	pdb.tx = nil
    130 	return err
    131 }
    132 
    133 func (pdb *pgDb) Abort(ctx context.Context) {
    134 	logg.InfoCtxf(ctx, "aborting tx", "tx", pdb.tx)
    135 	pdb.tx.Rollback(ctx)
    136 	pdb.tx = nil
    137 }
    138 
    139 // Put implements Db.
    140 func (pdb *pgDb) Put(ctx context.Context, key []byte, val []byte) error {
    141 	if !pdb.CheckPut() {
    142 		return errors.New("unsafe put and safety set")
    143 	}
    144 
    145 	lk, err := pdb.ToKey(ctx, key)
    146 	if err != nil {
    147 		return err
    148 	}
    149 
    150 	err = pdb.start(ctx)
    151 	if err != nil {
    152 		return err
    153 	}
    154 	logg.TraceCtxf(ctx, "put", "key", key, "val", val)
    155 	query := fmt.Sprintf("INSERT INTO %s.kv_vise (key, value, updated) VALUES ($1, $2, 'now') ON CONFLICT(key) DO UPDATE SET value = $2, updated = 'now';", pdb.schema)
    156 	actualKey := lk.Default
    157 	if lk.Translation != nil {
    158 		actualKey = lk.Translation
    159 	}
    160 
    161 	_, err = pdb.tx.Exec(ctx, query, actualKey, val)
    162 	if err != nil {
    163 		return err
    164 	}
    165 
    166 	return pdb.stopSingle(ctx)
    167 }
    168 
    169 // Get implements Db.
    170 func (pdb *pgDb) Get(ctx context.Context, key []byte) ([]byte, error) {
    171 	var rr []byte
    172 	lk, err := pdb.ToKey(ctx, key)
    173 	if err != nil {
    174 		return nil, err
    175 	}
    176 
    177 	err = pdb.start(ctx)
    178 	if err != nil {
    179 		return nil, err
    180 	}
    181 	logg.TraceCtxf(ctx, "get", "key", key)
    182 
    183 	if lk.Translation != nil {
    184 		query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema)
    185 		rs, err := pdb.tx.Query(ctx, query, lk.Translation)
    186 		if err != nil {
    187 			pdb.Abort(ctx)
    188 			return nil, err
    189 		}
    190 
    191 		if rs.Next() {
    192 			err = rs.Scan(&rr)
    193 			if err != nil {
    194 				pdb.Abort(ctx)
    195 				return nil, err
    196 			}
    197 
    198 			rs.Close()
    199 			err = pdb.stopSingle(ctx)
    200 			return rr, err
    201 		}
    202 	}
    203 
    204 	query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema)
    205 	rs, err := pdb.tx.Query(ctx, query, lk.Default)
    206 	if err != nil {
    207 		pdb.Abort(ctx)
    208 		return nil, err
    209 	}
    210 
    211 	if !rs.Next() {
    212 		rs.Close()
    213 		pdb.Abort(ctx)
    214 		return nil, db.NewErrNotFound(key)
    215 	}
    216 
    217 	err = rs.Scan(&rr)
    218 	if err != nil {
    219 		rs.Close()
    220 		pdb.Abort(ctx)
    221 		return nil, err
    222 	}
    223 	rs.Close()
    224 	err = pdb.stopSingle(ctx)
    225 	return rr, err
    226 }
    227 
    228 // Close implements Db.
    229 func (pdb *pgDb) Close(ctx context.Context) error {
    230 	err := pdb.Stop(ctx)
    231 	if err == db.ErrNoTx {
    232 		err = nil
    233 	}
    234 	pdb.conn.Close()
    235 	return err
    236 }
    237 
    238 // set up table
    239 func (pdb *pgDb) ensureTable(ctx context.Context) error {
    240 	if pdb.prepd {
    241 		logg.WarnCtxf(ctx, "ensureTable called more than once")
    242 		return nil
    243 	}
    244 	tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
    245 	if err != nil {
    246 		tx.Rollback(ctx)
    247 		return err
    248 	}
    249 	query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.kv_vise (
    250 		id SERIAL NOT NULL,
    251 		key BYTEA NOT NULL UNIQUE,
    252 		value BYTEA NOT NULL,
    253 		updated TIMESTAMP NOT NULL
    254 	);
    255 `, pdb.schema)
    256 	_, err = tx.Exec(ctx, query)
    257 	if err != nil {
    258 		tx.Rollback(ctx)
    259 		return err
    260 	}
    261 
    262 	err = tx.Commit(ctx)
    263 	if err != nil {
    264 		tx.Rollback(ctx)
    265 		return err
    266 	}
    267 	pdb.prepd = true
    268 	return nil
    269 }