commit fe4c7fb7c7edad6e8606692a77530e307b3e7c6c
parent f94d7ea8599743e5fcd4f4070f98bbb177613f53
Author: lash <dev@holbrook.no>
Date: Sat, 31 Aug 2024 17:08:31 +0100
Add state and cache invalidation to prevent invalid persist
Diffstat:
8 files changed, 140 insertions(+), 10 deletions(-)
diff --git a/cache/cache.go b/cache/cache.go
@@ -12,6 +12,7 @@ type Cache struct {
Cache []map[string]string // All loaded cache items
Sizes map[string]uint16 // Size limits for all loaded symbols.
LastValue string // last inserted value
+ invalid bool
}
// NewCache creates a new ready-to-use cache object
@@ -23,6 +24,20 @@ func NewCache() *Cache {
return ca
}
+// Invalidate marks a cache as invalid.
+//
+// An invalid cache should not be persisted or propagated
+func(ca *Cache) Invalidate() {
+ ca.invalid = true
+}
+
+// Invalid returns true if cache is invalid.
+//
+// An invalid cache should not be persisted or propagated
+func(ca *Cache) Invalid() bool {
+ return ca.invalid
+}
+
// WithCacheSize applies a cumulative cache size limitation for all cached items.
func(ca *Cache) WithCacheSize(cacheSize uint32) *Cache {
ca.CacheSize = cacheSize
diff --git a/cache/memory.go b/cache/memory.go
@@ -12,4 +12,6 @@ type Memory interface {
Levels() uint32
Keys(level uint32) []string
Last() string
+ Invalidate()
+ Invalid() bool
}
diff --git a/engine/engine.go b/engine/engine.go
@@ -117,16 +117,20 @@ func(en *Engine) runFirst(ctx context.Context) (bool, error) {
// TODO: typed error
err = fmt.Errorf("Pre-VM code cannot have remaining bytecode after execution, had: %x", b)
} else {
- if en.st.MatchFlag(state.FLAG_TERMINATE, true) {
+ if en.st.MatchFlag(state.FLAG_BLOCK, true) {
en.exit = en.ca.Last()
Logg.InfoCtxf(ctx, "Pre-VM check says not to continue execution", "state", en.st)
} else {
r = true
}
}
- en.st.ResetFlag(state.FLAG_TERMINATE)
+ if err != nil {
+ en.st.Invalidate()
+ en.ca.Invalidate()
+ }
+ en.st.ResetFlag(state.FLAG_BLOCK)
Logg.DebugCtxf(ctx, "end pre-VM check")
- return r, nil
+ return r, err
}
// Init must be explicitly called before using the Engine instance.
diff --git a/engine/engine_test.go b/engine/engine_test.go
@@ -356,7 +356,7 @@ func preBlock(ctx context.Context, sym string, input []byte) (resource.Result, e
log.Printf("executing preBlock")
return resource.Result{
Content: "None shall pass",
- FlagSet: []uint32{state.FLAG_TERMINATE},
+ FlagSet: []uint32{state.FLAG_BLOCK},
}, nil
}
diff --git a/persist/persist.go b/persist/persist.go
@@ -10,6 +10,7 @@ import (
"git.defalsify.org/vise.git/cache"
)
+// Persister abstracts storage and retrieval of state and cache.
type Persister struct {
State *state.State
Memory *cache.Cache
@@ -17,6 +18,7 @@ type Persister struct {
db db.Db
}
+// NewPersister creates a new Persister instance.
func NewPersister(db db.Db) *Persister {
return &Persister{
db: db,
@@ -24,16 +26,19 @@ func NewPersister(db db.Db) *Persister {
}
}
+// WithSession is a chainable function that sets the current golang context of the persister.
func(p *Persister) WithContext(ctx context.Context) *Persister {
p.ctx = ctx
return p
}
+// WithSession is a chainable function that sets the current session context of the persister.
func(p *Persister) WithSession(sessionId string) *Persister {
p.db.SetSession(sessionId)
return p
}
+
// WithContent sets a current State and Cache object.
//
// This method is normally called before Serialize / Save.
@@ -43,28 +48,39 @@ func(p *Persister) WithContent(st *state.State, ca *cache.Cache) *Persister {
return p
}
-// GetState implements the Persister interface.
+// Invalid checks if the underlying state has been invalidated.
+//
+// An invalid state will cause Save to panic.
+func(p *Persister) Invalid() bool {
+ return p.GetState().Invalid() || p.GetMemory().Invalid()
+}
+
+// GetState returns the state enclosed by the Persister.
func(p *Persister) GetState() *state.State {
return p.State
}
-// GetMemory implements the Persister interface.
+// GetMemory returns the cache (memory) enclosed by the Persister.
func(p *Persister) GetMemory() cache.Memory {
return p.Memory
}
-// Serialize implements the Persister interface.
+// Serialize encodes the state and cache into byte form for storage.
func(p *Persister) Serialize() ([]byte, error) {
return cbor.Marshal(p)
}
-// Deserialize implements the Persister interface.
+// Deserialize decodes the state and cache from storage, and applies them to the persister.
func(p *Persister) Deserialize(b []byte) error {
err := cbor.Unmarshal(b, p)
return err
}
+// Save perists the state and cache to the db.Db backend.
func(p *Persister) Save(key string) error {
+ if p.Invalid() {
+ panic("persister has been invalidated")
+ }
b, err := p.Serialize()
if err != nil {
return err
@@ -73,6 +89,7 @@ func(p *Persister) Save(key string) error {
return p.db.Put(p.ctx, []byte(key), b)
}
+// Load retrieves state and cache from the db.Db backend.
func(p *Persister) Load(key string) error {
p.db.SetPrefix(db.DATATYPE_STATE)
b, err := p.db.Get(p.ctx, []byte(key))
diff --git a/persist/persist_test.go b/persist/persist_test.go
@@ -0,0 +1,77 @@
+package persist
+
+import (
+ "context"
+ "testing"
+
+ "git.defalsify.org/vise.git/db"
+ "git.defalsify.org/vise.git/state"
+ "git.defalsify.org/vise.git/cache"
+)
+
+func TestInvalidateState(t *testing.T) {
+ st := state.NewState(0)
+ ca := cache.NewCache()
+
+ ctx := context.Background()
+ store := db.NewMemDb(ctx)
+ store.Connect(ctx, "")
+ pr := NewPersister(store).WithSession("xyzzy").WithContent(&st, ca)
+ err := pr.Save("foo")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ st.Invalidate()
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("expected panic")
+ }
+ }()
+ _ = pr.Save("foo")
+}
+
+func TestInvalidateCache(t *testing.T) {
+ st := state.NewState(0)
+ ca := cache.NewCache()
+
+ ctx := context.Background()
+ store := db.NewMemDb(ctx)
+ store.Connect(ctx, "")
+ pr := NewPersister(store).WithSession("xyzzy").WithContent(&st, ca)
+ err := pr.Save("foo")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ca.Invalidate()
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("expected panic")
+ }
+ }()
+ _ = pr.Save("foo")
+}
+
+func TestInvalidateAll(t *testing.T) {
+ st := state.NewState(0)
+ ca := cache.NewCache()
+
+ ctx := context.Background()
+ store := db.NewMemDb(ctx)
+ store.Connect(ctx, "")
+ pr := NewPersister(store).WithSession("xyzzy").WithContent(&st, ca)
+ err := pr.Save("foo")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ca.Invalidate()
+ st.Invalidate()
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("expected panic")
+ }
+ }()
+ _ = pr.Save("foo")
+}
diff --git a/state/flag.go b/state/flag.go
@@ -7,13 +7,13 @@ const (
FLAG_WAIT
FLAG_LOADFAIL
FLAG_TERMINATE
- FLAG_RESERVED
+ FLAG_BLOCK
FLAG_LANG
FLAG_USERSTART = 8
)
func IsWriteableFlag(flag uint32) bool {
- if flag > 4 {
+ if flag > 5 {
return true
}
//if flag & FLAG_WRITEABLE > 0 {
diff --git a/state/state.go b/state/state.go
@@ -39,6 +39,7 @@ type State struct {
Language *lang.Language // Language selector for rendering
input []byte // Last input
debug bool // Make string representation more human friendly
+ invalid bool
}
// number of bytes necessary to represent a bitfield of the given size.
@@ -53,6 +54,20 @@ func toByteSize(BitSize uint32) uint8 {
return uint8(BitSize / 8)
}
+// Invalidate marks a state as invalid.
+//
+// An invalid state should not be persisted or propagated
+func(st *State) Invalidate() {
+ st.invalid = true
+}
+
+// Invalid returns true if state is invalid.
+//
+// An invalid state should not be persisted or propagated
+func(st *State) Invalid() bool {
+ return st.invalid
+}
+
//// Retrieve the state of a state flag
//func getFlag(bitIndex uint32, bitField []byte) bool {
// byteIndex := bitIndex / 8