diff --git a/datastructures.go b/datastructures.go index d6fcb8b..5c8abb9 100644 --- a/datastructures.go +++ b/datastructures.go @@ -38,7 +38,20 @@ func (sbt *StringyBTreeNode) Right() TreeNode { } func (sbt *StringyBTreeNode) Print(w io.Writer) { - fmt.Printf("%+v\n", sbt) + l, lok := sbt.left.(*StringyBTreeNode) + r, rok := sbt.right.(*StringyBTreeNode) + fmt.Fprintf(w, "k: %s v: %v, ", sbt.k, sbt.v) + if !rok { + fmt.Fprint(w, "r: _, ") + } else { + fmt.Fprintf(w, "r: %s, ", r.k) + } + if !lok { + fmt.Fprint(w, "l: _, ") + } else { + fmt.Fprintf(w, "l: %s, ", l.k) + } + fmt.Fprint(w, "\n") if sbt.left != nil { sbt.left.Print(w) } @@ -71,18 +84,39 @@ func (sbt *StringyBTreeNode) Insert(k string, v interface{}) *StringyBTreeNode { } func (sbt *StringyBTreeNode) Get(k string) interface{} { + r := sbt.getNode(k) + if r == nil { + return nil + } + return r.v +} + +func (sbt *StringyBTreeNode) getNode(k string) *StringyBTreeNode { if sbt.k == k { - return sbt.v + return sbt } - if sbt.k > k { - if sbt.left != nil { - return sbt.left.(*StringyBTreeNode).Get(k) - } + if sbt.left != nil { + return sbt.left.(*StringyBTreeNode).getNode(k) } - if sbt.k < k { - if sbt.right != nil { - return sbt.right.(*StringyBTreeNode).Get(k) - } + if sbt.right != nil { + return sbt.right.(*StringyBTreeNode).getNode(k) } return nil // not found } + +func (sbt *StringyBTreeNode) rightmost(parent *StringyBTreeNode) (*StringyBTreeNode, *StringyBTreeNode) { + if sbt.right == nil { + return parent, sbt + } + return sbt.right.(*StringyBTreeNode).rightmost(sbt) +} + +// swap the rightmost node to the node we're removing +func (sbt *StringyBTreeNode) Delete(k string) *StringyBTreeNode { + parent, rightmost := sbt.rightmost(nil) + delete := sbt.getNode(k) + delete.k = rightmost.k + delete.v = rightmost.v + parent.right = nil + return sbt +} diff --git a/datastructures_test.go b/datastructures_test.go index 020ec40..2b7ccac 100644 --- a/datastructures_test.go +++ b/datastructures_test.go @@ -1,6 +1,7 @@ package main import ( + "os" "testing" ) @@ -27,6 +28,39 @@ func TestBuildingTree(t *testing.T) { } } +func TestDeletingFromTree(t *testing.T) { + tree := NewStringyBTree("foo", 1).Insert("bar", 2).Insert("baz", 3).Insert("quuz", 4) + tree.Delete("bar") + if tree.Get("foo") != 1 { + t.Fail() + } + if tree.Get("bar") != nil { + t.Fail() + } + if tree.Get("baz") != 3 { + t.Fail() + } + if tree.Get("quuz") != 4 { + t.Fail() + } + tree.Print(os.Stdout) + + tree = NewStringyBTree("foo", 1).Insert("bar", 2).Insert("baz", 3).Insert("quuz", 4) + tree.Delete("foo") + if tree.Get("foo") != nil { + t.Fail() + } + if tree.Get("bar") != 2 { + t.Fail() + } + if tree.Get("baz") != 3 { + t.Fail() + } + if tree.Get("quuz") != 4 { + t.Fail() + } +} + var r TreeNode func BenchmarkInserts(b *testing.B) {