commit 8ad1d92d2127cb1ee5734804cf56ce99b73d8c2b
parent 671c69442487ac81337a5ce0ab2a3303ae6b1bef
Author: lash <dev@holbrook.no>
Date: Sun, 19 Jan 2025 08:42:38 +0000
Add tests for commit, rollback postgres
Diffstat:
3 files changed, 217 insertions(+), 11 deletions(-)
diff --git a/db/db.go b/db/db.go
@@ -79,6 +79,7 @@ type Db interface {
DecodeKey(ctx context.Context, key []byte) ([]byte, error)
Start(context.Context) error
Stop(context.Context) error
+ Abort(context.Context)
}
type LookupKey struct {
@@ -269,3 +270,6 @@ func (bd *DbBase) Start(ctx context.Context) error {
func (bd *DbBase) Stop(ctx context.Context) error {
return nil
}
+
+func (bd *DbBase) Abort(ctx context.Context) {
+}
diff --git a/db/postgres/pg.go b/db/postgres/pg.go
@@ -85,6 +85,9 @@ func (pdb *pgDb) Start(ctx context.Context) error {
}
func (pdb *pgDb) start(ctx context.Context) error {
+ if pdb.tx != nil {
+ return nil
+ }
tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
logg.TraceCtxf(ctx, "begin single tx", "err", err)
if err != nil {
@@ -121,7 +124,8 @@ func (pdb *pgDb) stop(ctx context.Context) error {
return err
}
-func (pdb *pgDb) abort(ctx context.Context) {
+func (pdb *pgDb) Abort(ctx context.Context) {
+ logg.InfoCtxf(ctx, "aborting tx", "tx", pdb.tx)
pdb.tx.Rollback(ctx)
}
@@ -171,7 +175,7 @@ func (pdb *pgDb) Get(ctx context.Context, key []byte) ([]byte, error) {
query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema)
rs, err := pdb.tx.Query(ctx, query, lk.Translation)
if err != nil {
- pdb.abort(ctx)
+ pdb.Abort(ctx)
return nil, err
}
defer rs.Close()
@@ -189,13 +193,13 @@ func (pdb *pgDb) Get(ctx context.Context, key []byte) ([]byte, error) {
query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema)
rs, err := pdb.tx.Query(ctx, query, lk.Default)
if err != nil {
- pdb.abort(ctx)
+ pdb.Abort(ctx)
return nil, err
}
defer rs.Close()
if !rs.Next() {
- pdb.abort(ctx)
+ pdb.Abort(ctx)
return nil, db.NewErrNotFound(key)
}
diff --git a/db/postgres/pg_test.go b/db/postgres/pg_test.go
@@ -15,6 +15,16 @@ import (
"git.defalsify.org/vise.git/db/dbtest"
)
+var (
+ typMap = pgtype.NewMap()
+
+ mockVfd = pgconn.FieldDescription{
+ Name: "value",
+ DataTypeOID: pgtype.ByteaOID,
+ Format: typMap.FormatCodeForOID(pgtype.ByteaOID),
+ }
+)
+
func TestCasesPg(t *testing.T) {
ctx := context.Background()
@@ -52,8 +62,6 @@ func TestPutGetPg(t *testing.T) {
dbi = store
_ = dbi
- typMap := pgtype.NewMap()
-
k := []byte("foo")
ks := append([]byte{db.DATATYPE_USERDATA}, []byte(ses)...)
ks = append(ks, []byte(".")...)
@@ -68,11 +76,6 @@ func TestPutGetPg(t *testing.T) {
t.Fatal(err)
}
- mockVfd := pgconn.FieldDescription{
- Name: "value",
- DataTypeOID: pgtype.ByteaOID,
- Format: typMap.FormatCodeForOID(pgtype.ByteaOID),
- }
row := pgxmock.NewRowsWithColumnDefinition(mockVfd)
row = row.AddRow(v)
mock.ExpectBegin()
@@ -128,3 +131,198 @@ func TestPutGetPg(t *testing.T) {
}
}
+
+func TestPostgresTxAbort(t *testing.T) {
+ var dbi db.Db
+ ses := "xyzzy"
+
+ mock, err := pgxmock.NewPool()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer mock.Close()
+
+ store := NewPgDb().WithConnection(mock).WithSchema("vvise")
+ store.SetPrefix(db.DATATYPE_USERDATA)
+ store.SetSession(ses)
+ ctx := context.Background()
+
+ dbi = store
+ _ = dbi
+
+ resInsert := pgxmock.NewResult("UPDATE", 1)
+ k := []byte("foo")
+ ks := append([]byte{db.DATATYPE_USERDATA}, []byte(ses)...)
+ ks = append(ks, []byte(".")...)
+ ks = append(ks, k...)
+ v := []byte("bar")
+ //mock.ExpectBegin()
+ mock.ExpectBeginTx(defaultTxOptions)
+ mock.ExpectExec("INSERT INTO vvise.kv_vise").WithArgs(ks, v).WillReturnResult(resInsert)
+ mock.ExpectRollback()
+ err = store.Start(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Put(ctx, k, v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ store.Abort(ctx)
+}
+
+func TestPostgresTxCommitOnClose(t *testing.T) {
+ var dbi db.Db
+ ses := "xyzzy"
+
+ mock, err := pgxmock.NewPool()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer mock.Close()
+
+ store := NewPgDb().WithConnection(mock).WithSchema("vvise")
+ store.SetPrefix(db.DATATYPE_USERDATA)
+ store.SetSession(ses)
+ ctx := context.Background()
+
+ dbi = store
+ _ = dbi
+
+ resInsert := pgxmock.NewResult("UPDATE", 1)
+ k := []byte("foo")
+ ks := append([]byte{db.DATATYPE_USERDATA}, []byte(ses)...)
+ ks = append(ks, []byte(".")...)
+ ks = append(ks, k...)
+ v := []byte("bar")
+
+ ktwo := []byte("blinky")
+ kstwo := append([]byte{db.DATATYPE_USERDATA}, []byte(ses)...)
+ kstwo = append(kstwo, []byte(".")...)
+ kstwo = append(kstwo, ktwo...)
+ vtwo := []byte("clyde")
+
+ mock.ExpectBeginTx(defaultTxOptions)
+ mock.ExpectExec("INSERT INTO vvise.kv_vise").WithArgs(ks, v).WillReturnResult(resInsert)
+ mock.ExpectExec("INSERT INTO vvise.kv_vise").WithArgs(kstwo, vtwo).WillReturnResult(resInsert)
+ mock.ExpectCommit()
+
+ err = store.Start(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Put(ctx, k, v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Put(ctx, ktwo, vtwo)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Close(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ row := pgxmock.NewRowsWithColumnDefinition(mockVfd)
+ row = row.AddRow(v)
+ mock.ExpectBeginTx(defaultTxOptions)
+ mock.ExpectQuery("SELECT value FROM vvise.kv_vise").WithArgs(ks).WillReturnRows(row)
+ mock.ExpectCommit()
+ row = pgxmock.NewRowsWithColumnDefinition(mockVfd)
+ row = row.AddRow(vtwo)
+ mock.ExpectBeginTx(defaultTxOptions)
+ mock.ExpectQuery("SELECT value FROM vvise.kv_vise").WithArgs(kstwo).WillReturnRows(row)
+ mock.ExpectCommit()
+
+ store = NewPgDb().WithConnection(mock).WithSchema("vvise")
+ store.SetPrefix(db.DATATYPE_USERDATA)
+ store.SetSession(ses)
+ v, err = store.Get(ctx, k)
+ if err != nil {
+ if !db.IsNotFound(err) {
+ t.Fatalf("get key one: %x", k)
+ }
+ }
+ v, err = store.Get(ctx, ktwo)
+ if err != nil {
+ if !db.IsNotFound(err) {
+ t.Fatalf("get key two: %x", ktwo)
+ }
+ }
+}
+
+func TestPostgresTxStartStop(t *testing.T) {
+ var dbi db.Db
+ ses := "xyzzy"
+
+ mock, err := pgxmock.NewPool()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer mock.Close()
+
+ store := NewPgDb().WithConnection(mock).WithSchema("vvise")
+ store.SetPrefix(db.DATATYPE_USERDATA)
+ store.SetSession(ses)
+ ctx := context.Background()
+
+ dbi = store
+ _ = dbi
+
+ resInsert := pgxmock.NewResult("UPDATE", 1)
+ k := []byte("inky")
+ ks := append([]byte{db.DATATYPE_USERDATA}, []byte(ses)...)
+ ks = append(ks, []byte(".")...)
+ ks = append(ks, k...)
+ v := []byte("pinky")
+
+ ktwo := []byte("blinky")
+ kstwo := append([]byte{db.DATATYPE_USERDATA}, []byte(ses)...)
+ kstwo = append(kstwo, []byte(".")...)
+ kstwo = append(kstwo, ktwo...)
+ vtwo := []byte("clyde")
+ mock.ExpectBeginTx(defaultTxOptions)
+ mock.ExpectExec("INSERT INTO vvise.kv_vise").WithArgs(ks, v).WillReturnResult(resInsert)
+ mock.ExpectExec("INSERT INTO vvise.kv_vise").WithArgs(kstwo, vtwo).WillReturnResult(resInsert)
+ mock.ExpectCommit()
+
+ row := pgxmock.NewRowsWithColumnDefinition(mockVfd)
+ row = row.AddRow(v)
+ mock.ExpectBeginTx(defaultTxOptions)
+ mock.ExpectQuery("SELECT value FROM vvise.kv_vise").WithArgs(ks).WillReturnRows(row)
+ row = pgxmock.NewRowsWithColumnDefinition(mockVfd)
+ row = row.AddRow(vtwo)
+ mock.ExpectQuery("SELECT value FROM vvise.kv_vise").WithArgs(kstwo).WillReturnRows(row)
+ mock.ExpectCommit()
+
+ err = store.Start(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Put(ctx, k, v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Put(ctx, ktwo, vtwo)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Stop(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ v, err = store.Get(ctx, k)
+ if err != nil {
+ t.Fatal(err)
+ }
+ v, err = store.Get(ctx, ktwo)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = store.Close(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+}