diff --git a/main.go b/main.go index 9e9420b..9095d50 100644 --- a/main.go +++ b/main.go @@ -1,8 +1,13 @@ package main import ( + "fmt" + "io" "log" "net/http" + "os" + "path/filepath" + "strconv" "strings" "time" ) @@ -19,8 +24,9 @@ type PageProvider interface { } type RootHandler struct { - Sessions SessionProvider - Pages PageProvider + Sessions SessionProvider + Pages PageProvider + StaticDir string } type AdminHandler struct { @@ -47,25 +53,57 @@ func (h *RootHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // fall back to serving out of the static directory, but: // 1. prevent the generated indexes from rendering if strings.HasSuffix(r.URL.Path, "/") { - http.NotFound(w, r) + h.ErrorHandle(404, w) return } // 2. prevent hidden paths from rendering for _, seg := range strings.Split(r.URL.Path, "/") { if strings.HasPrefix(seg, ".") { - http.NotFound(w, r) + h.ErrorHandle(404, w) return } } + // 3. catch files that would 404 and serve our own 404 page + if !staticFileExists(filepath.Join(h.StaticDir, r.URL.Path)) { + h.ErrorHandle(404, w) + return + } // finally, use the built-in fileserver to serve - fs := http.FileServer(http.Dir("./static")) + fs := http.FileServer(http.Dir(h.StaticDir)) fs.ServeHTTP(w, r) } +func (h *RootHandler) ErrorHandle(status int, w http.ResponseWriter) { + f, err := os.Open(filepath.Join(h.StaticDir, strconv.Itoa(status)+".html")) + if err == nil { + w.WriteHeader(status) + _, err = io.Copy(w, f) + if err != nil { + fmt.Fprintf(w, "Internal Server Error while loading %d page\n", status) + } + return + } + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(404) + fmt.Fprintf(w, "%d\n", status) +} + +func staticFileExists(name string) bool { + f, err := os.Open(name) + if err != nil { + return false + } + defer f.Close() + _, err = f.Stat() + return err == nil +} + func main() { handler := &RootHandler{ - Sessions: &Sessions{}, - Pages: &Index{}, + Sessions: &Sessions{}, + Pages: &Index{}, + StaticDir: "./static", } handler.Pages.Save("foo", &Page{ Contents: []byte("foobar"),