diff --git a/model.go b/model.go index 6c19930..6a69798 100644 --- a/model.go +++ b/model.go @@ -39,15 +39,19 @@ func (i *Index) Page(key string) (*Page, error) { curr, rest, found := strings.Cut(key, "/") page, pageok := i.Pages[curr] child, childok := i.Children[curr] + // no trailing slash and doesn't exist if !found && !pageok { return &Page{}, ErrPageNotFound } + // trailing slash, exists as page if rest == "" && pageok { return &page, nil } + // neither page nor child if !childok { return &Page{}, ErrPageNotFound } + // recurse return (&child).Page(rest) } @@ -64,22 +68,27 @@ func (i *Index) Save(key string, page *Page) error { if key[0] == '/' { // strip leading slash key = key[1:] } + // init maps if not yet created if i.Pages == nil { i.Pages = map[string]Page{} } if i.Children == nil { i.Children = map[string]Index{} } + + // save as page curr, rest, _ := strings.Cut(key, "/") if rest == "" { i.Pages[curr] = *page - } else { - children := i.Children[curr] - err := (&children).Save(rest, page) - if err != nil { - return err - } - i.Children[curr] = children + return nil } + + // recurse and save in child + children := i.Children[curr] + err := (&children).Save(rest, page) + if err != nil { + return err + } + i.Children[curr] = children return nil } diff --git a/model_test.go b/model_test.go index 0abe80f..7241132 100644 --- a/model_test.go +++ b/model_test.go @@ -155,7 +155,17 @@ func TestPage(t *testing.T) { func TestSave(t *testing.T) { i := &Index{} - err := i.Save("foo", &Page{Title: "fooroot"}) + err := i.Save("/", &Page{Title: "invalid"}) + if err == nil { + t.Logf("expected err, received nil") + } + _, err = i.Page("/") + if err == nil { + t.Logf("expected err, received nil") + t.Fail() + } + + err = i.Save("foo", &Page{Title: "fooroot"}) if err != nil { t.Logf("expected no err, received %v", err) } @@ -169,7 +179,7 @@ func TestSave(t *testing.T) { t.Fail() } - err = i.Save("foo/", &Page{Title: "fooroot2"}) + err = i.Save("/foo/", &Page{Title: "fooroot2"}) if err != nil { t.Logf("expected no err, received %v", err) }