commit 671c69442487ac81337a5ce0ab2a3303ae6b1bef
parent 46dfca563b969d7bb4e4dc75f0d3311d41af5452
Author: lash <dev@holbrook.no>
Date: Sun, 19 Jan 2025 08:08:32 +0000
Add tx capability, context in db closers
Diffstat:
19 files changed, 123 insertions(+), 48 deletions(-)
diff --git a/db/db.go b/db/db.go
@@ -40,10 +40,8 @@ type Db interface {
//
// If called more than once, consecutive calls should be ignored.
Connect(ctx context.Context, connStr string) error
- // Close implements io.Closer.
- //
// MUST be called before termination after a Connect().
- Close() error
+ Close(context.Context) error
// Get retrieves the value belonging to a key.
//
// Errors if the key does not exist, or if the retrieval otherwise fails.
@@ -79,6 +77,8 @@ type Db interface {
Prefix() uint8
Dump(context.Context, []byte) (*Dumper, error)
DecodeKey(ctx context.Context, key []byte) ([]byte, error)
+ Start(context.Context) error
+ Stop(context.Context) error
}
type LookupKey struct {
@@ -261,3 +261,11 @@ func(bd *DbBase) DecodeKey(ctx context.Context, key []byte) ([]byte, error) {
logg.DebugCtxf(ctx, "decoded key", "key", key, "fromkey", oldKey)
return key, nil
}
+
+func (bd *DbBase) Start(ctx context.Context) error {
+ return nil
+}
+
+func (bd *DbBase) Stop(ctx context.Context) error {
+ return nil
+}
diff --git a/db/error.go b/db/error.go
@@ -1,6 +1,7 @@
package db
import (
+ "errors"
"fmt"
"strings"
)
@@ -9,6 +10,12 @@ const (
notFoundPrefix = "key not found: "
)
+var (
+ ErrTxExist = errors.New("tx already exists")
+ ErrNoTx = errors.New("tx does not exist")
+ ErrSingleTx = errors.New("not a multi-instruction tx")
+)
+
// ErrNotFound is returned with a key was successfully queried, but did not match a stored key.
type ErrNotFound struct {
k []byte
diff --git a/db/fs/fs.go b/db/fs/fs.go
@@ -153,7 +153,7 @@ func(fdb *fsDb) Put(ctx context.Context, key []byte, val []byte) error {
}
// Close implements the Db interface.
-func(fdb *fsDb) Close() error {
+func(fdb *fsDb) Close(ctx context.Context) error {
return nil
}
diff --git a/db/fs/fs_test.go b/db/fs/fs_test.go
@@ -152,7 +152,7 @@ func TestReopen(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- err = store.Close()
+ err = store.Close(ctx)
if err != nil {
t.Fatal(err)
}
diff --git a/db/gdbm/gdbm.go b/db/gdbm/gdbm.go
@@ -132,7 +132,7 @@ func(gdb *gdbmDb) Get(ctx context.Context, key []byte) ([]byte, error) {
}
// Close implements Db
-func(gdb *gdbmDb) Close() error {
- logg.Tracef("closing gdbm", "path", gdb.conn)
+func(gdb *gdbmDb) Close(ctx context.Context) error {
+ logg.TraceCtxf(ctx, "closing gdbm", "path", gdb.conn)
return gdb.conn.Close()
}
diff --git a/db/gdbm/gdbm_test.go b/db/gdbm/gdbm_test.go
@@ -103,7 +103,7 @@ func TestConnect(t *testing.T) {
if !store.CheckPut() {
t.Fatal("expected checkput false")
}
- err = store.Close()
+ err = store.Close(ctx)
if err != nil {
t.Fatal(err)
}
@@ -133,7 +133,7 @@ func TestReopen(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- err = store.Close()
+ err = store.Close(ctx)
if err != nil {
t.Fatal(err)
}
diff --git a/db/mem/mem.go b/db/mem/mem.go
@@ -99,6 +99,6 @@ func(mdb *memDb) Put(ctx context.Context, key []byte, val []byte) error {
}
// Close implements Db
-func(mdb *memDb) Close() error {
+func(mdb *memDb) Close(ctx context.Context) error {
return nil
}
diff --git a/db/postgres/pg.go b/db/postgres/pg.go
@@ -29,6 +29,8 @@ type pgDb struct {
prepd bool
it pgx.Rows
itBase []byte
+ tx pgx.Tx
+ multi bool
}
// NewpgDb creates a new Postgres backed Db implementation.
@@ -67,7 +69,60 @@ func(pdb *pgDb) Connect(ctx context.Context, connStr string) error {
}
pdb.conn = conn
- return pdb.Prepare(ctx)
+ return pdb.ensureTable(ctx)
+}
+
+func (pdb *pgDb) Start(ctx context.Context) error {
+ if pdb.tx != nil {
+ return db.ErrTxExist
+ }
+ err := pdb.start(ctx)
+ if err != nil {
+ return err
+ }
+ pdb.multi = true
+ return nil
+}
+
+func (pdb *pgDb) start(ctx context.Context) error {
+ tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
+ logg.TraceCtxf(ctx, "begin single tx", "err", err)
+ if err != nil {
+ return err
+ }
+ pdb.tx = tx
+ return nil
+}
+
+func (pdb *pgDb) Stop(ctx context.Context) error {
+ if !pdb.multi {
+ return db.ErrSingleTx
+ }
+ return pdb.stop(ctx)
+}
+
+func (pdb *pgDb) stopSingle(ctx context.Context) error {
+ if pdb.multi {
+ return nil
+ }
+ err := pdb.tx.Commit(ctx)
+ logg.TraceCtxf(ctx, "stop single tx", "err", err)
+ pdb.tx = nil
+ return err
+}
+
+func (pdb *pgDb) stop(ctx context.Context) error {
+ if pdb.tx == nil {
+ return db.ErrNoTx
+ }
+ err := pdb.tx.Commit(ctx)
+ logg.TraceCtxf(ctx, "stop multi tx", "err", err)
+ pdb.tx = nil
+ return err
+}
+
+func (pdb *pgDb) abort(ctx context.Context) {
+ pdb.tx.Rollback(ctx)
}
// Put implements Db.
@@ -81,7 +136,7 @@ func(pdb *pgDb) Put(ctx context.Context, key []byte, val []byte) error {
return err
}
- tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
+ err = pdb.start(ctx)
if err != nil {
return err
}
@@ -91,12 +146,12 @@ func(pdb *pgDb) Put(ctx context.Context, key []byte, val []byte) error {
actualKey = lk.Translation
}
- _, err = tx.Exec(ctx, query, actualKey, val)
+ _, err = pdb.tx.Exec(ctx, query, actualKey, val)
if err != nil {
return err
}
- return tx.Commit(ctx)
+ return pdb.stopSingle(ctx)
}
// Get implements Db.
@@ -106,56 +161,63 @@ func (pdb *pgDb) Get(ctx context.Context, key []byte) ([]byte, error) {
return nil, err
}
- tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
+ //tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
+ err = pdb.start(ctx)
if err != nil {
return nil, err
}
if lk.Translation != nil {
query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema)
- rs, err := tx.Query(ctx, query, lk.Translation)
+ rs, err := pdb.tx.Query(ctx, query, lk.Translation)
if err != nil {
- tx.Rollback(ctx)
+ pdb.abort(ctx)
return nil, err
}
defer rs.Close()
if rs.Next() {
+ // TODO: encode non raw
r := rs.RawValues()
- tx.Commit(ctx)
- tx.Rollback(ctx)
- return r[0], nil
+ //tx.Commit(ctx)
+ //tx.Rollback(ctx)
+ err = pdb.stopSingle(ctx)
+ return r[0], err
}
}
query := fmt.Sprintf("SELECT value FROM %s.kv_vise WHERE key = $1", pdb.schema)
- rs, err := tx.Query(ctx, query, lk.Default)
+ rs, err := pdb.tx.Query(ctx, query, lk.Default)
if err != nil {
- tx.Rollback(ctx)
+ pdb.abort(ctx)
return nil, err
}
defer rs.Close()
if !rs.Next() {
- tx.Rollback(ctx)
+ pdb.abort(ctx)
return nil, db.NewErrNotFound(key)
}
r := rs.RawValues()
- tx.Commit(ctx)
- return r[0], nil
+ err = pdb.stopSingle(ctx)
+ return r[0], err
}
// Close implements Db.
-func(pdb *pgDb) Close() error {
+func(pdb *pgDb) Close(ctx context.Context) error {
+ err := pdb.Stop(ctx)
+ if err == db.ErrNoTx {
+ err = nil
+ }
pdb.conn.Close()
- return nil
+ return err
}
// set up table
-func(pdb *pgDb) Prepare(ctx context.Context) error {
+func(pdb *pgDb) ensureTable(ctx context.Context) error {
if pdb.prepd {
- logg.WarnCtxf(ctx, "Prepare called more than once")
+ logg.WarnCtxf(ctx, "ensureTable called more than once")
return nil
}
tx, err := pdb.conn.BeginTx(ctx, defaultTxOptions)
diff --git a/dev/dbconvert/main.go b/dev/dbconvert/main.go
@@ -47,7 +47,7 @@ func newScanner(ctx context.Context, db db.Db) (*scanner, error) {
}
func(sc *scanner) Close() error {
- return sc.db.Close()
+ return sc.db.Close(sc.ctx)
}
func(sc *scanner) Scan(fp string, d fs.DirEntry, err error) error {
diff --git a/engine/db.go b/engine/db.go
@@ -367,7 +367,7 @@ func(en *DefaultEngine) runFirst(ctx context.Context) (bool, error) {
// An error will be logged and returned if:
// * persistence was attempted and failed (takes precedence)
// * resource backend did not close cleanly.
-func(en *DefaultEngine) Finish() error {
+func(en *DefaultEngine) Finish(ctx context.Context) error {
var perr error
if !en.initd {
return nil
@@ -375,7 +375,7 @@ func(en *DefaultEngine) Finish() error {
if en.pe != nil {
perr = en.pe.Save(en.cfg.SessionId)
}
- err := en.rs.Close()
+ err := en.rs.Close(ctx)
if err != nil {
logg.Errorf("resource close failed!", "err", err)
}
diff --git a/engine/db_test.go b/engine/db_test.go
@@ -64,7 +64,7 @@ func TestDbEngineMinimal(t *testing.T) {
if cont {
t.Fatalf("expected not continue")
}
- err = en.Finish()
+ err = en.Finish(ctx)
if err != nil {
t.Fatal(err)
}
@@ -205,7 +205,7 @@ func TestDbEngineRoot(t *testing.T) {
if err == nil {
t.Fatalf("expected nocode")
}
- err = en.Finish()
+ err = en.Finish(ctx)
if err != nil {
t.Fatal(err)
}
@@ -245,7 +245,7 @@ func TestDbEnginePersist(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- err = en.Finish()
+ err = en.Finish(ctx)
if err != nil {
t.Fatal(err)
}
diff --git a/engine/engine.go b/engine/engine.go
@@ -16,5 +16,5 @@ type Engine interface {
// VM execution.
Flush(ctx context.Context, w io.Writer) (int, error)
// Finish must be called after the last call to Exec.
- Finish() error
+ Finish(ctx context.Context) error
}
diff --git a/engine/loop.go b/engine/loop.go
@@ -21,7 +21,7 @@ import (
// If initial is set, the value will be used for the first (initializing) execution
// If nil, an empty byte value will be used.
func Loop(ctx context.Context, en Engine, reader io.Reader, writer io.Writer, initial []byte) error {
- defer en.Finish()
+ defer en.Finish(ctx)
if initial == nil {
initial = []byte{}
}
diff --git a/engine/persist_test.go b/engine/persist_test.go
@@ -54,7 +54,7 @@ func TestPersistNewAcrossEngine(t *testing.T) {
t.Fatal(err)
}
- err = en.Finish()
+ err = en.Finish(ctx)
if err != nil {
t.Fatal(err)
}
@@ -119,7 +119,7 @@ func TestPersistSameAcrossEngine(t *testing.T) {
t.Fatal(err)
}
- err = en.Finish()
+ err = en.Finish(ctx)
if err != nil {
t.Fatal(err)
}
diff --git a/examples/first/main.go b/examples/first/main.go
@@ -180,7 +180,7 @@ func main() {
os.Exit(1)
}
- err = en.Finish()
+ err = en.Finish(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "engine finish fail: %v\n", err)
os.Exit(1)
diff --git a/examples/http/main.go b/examples/http/main.go
@@ -167,7 +167,7 @@ func(f *DefaultSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Reques
f.writeError(w, 500, "Write result fail", err)
return
}
- err = en.Finish()
+ err = en.Finish(ctx)
if err != nil {
f.writeError(w, 500, "Engine finish fail", err)
return
diff --git a/examples/longmenu/main.go b/examples/longmenu/main.go
@@ -32,7 +32,7 @@ func main() {
os.Exit(1)
}
rs := resource.NewDbResource(store)
- defer rs.Close()
+ defer rs.Close(ctx)
cfg := engine.Config {
OutputSize: uint32(size),
}
diff --git a/resource/db.go b/resource/db.go
@@ -156,6 +156,6 @@ func(g *DbResource) DbFuncFor(ctx context.Context, sym string) (EntryFunc, error
}
// Close implements the Resource interface.
-func(g *DbResource) Close() error {
- return g.db.Close()
+func(g *DbResource) Close(ctx context.Context) error {
+ return g.db.Close(ctx)
}
diff --git a/resource/resource.go b/resource/resource.go
@@ -45,10 +45,8 @@ type Resource interface {
GetMenu(ctx context.Context, menuSym string) (string, error)
// FuncFor retrieves the external function (EntryFunc) associated with the given symbol.
FuncFor(ctx context.Context, loadSym string) (EntryFunc, error)
- // Close implements the io.Closer interface.
- //
// Safely shuts down retrieval backend.
- Close() error
+ Close(ctx context.Context) error
}
// MenuResource contains the base definition for building Resource implementations.
@@ -144,6 +142,6 @@ func(m *MenuResource) FallbackFunc(ctx context.Context, sym string) (EntryFunc,
}
// Close implements the Resource interface.
-func(m *MenuResource) Close() error {
+func(m *MenuResource) Close(ctx context.Context) error {
return nil
}