diff --git a/certificates/certificates.go b/certificates/certificates.go new file mode 100644 index 0000000..3ea4cd2 --- /dev/null +++ b/certificates/certificates.go @@ -0,0 +1,97 @@ +package certificates + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "strings" + "time" +) + +func GenerateKeyPair(host string) (*rsa.PrivateKey, []byte, error) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return &rsa.PrivateKey{}, []byte{}, err + } + keyUsage := x509.KeyUsageDigitalSignature + notBefore := time.Now() + notAfter := notBefore.Add(time.Hour * 24 * 365 * 5) // five yearss + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return &rsa.PrivateKey{}, []byte{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Castor Server"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + hosts := strings.Split(host, ",") + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + bytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return &rsa.PrivateKey{}, []byte{}, err + } + return priv, bytes, nil +} + +func TestCertificateExists(certname, keyname string) error { + _, err := os.Stat(certname) + if err != nil { + return fmt.Errorf("Could not open %s; %w", certname, err) + } + _, err = os.Stat(keyname) + if err != nil { + return fmt.Errorf("Could not open %s; %w", keyname, err) + } + return nil +} + +func WriteCertsToFile(certname, keyname string, cert []byte, privkey *rsa.PrivateKey) error { + certOut, err := os.Create(certname) + if err != nil { + return fmt.Errorf("Failed to open %s for writing: %w", certname, err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: cert}); err != nil { + return fmt.Errorf("Failed to write data to %s: %w", certname, err) + } + if err := certOut.Close(); err != nil { + return fmt.Errorf("Error closing %s: %w", certname, err) + } + keyOut, err := os.OpenFile(keyname, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return fmt.Errorf("Failed to open %s for writing: %w", keyname, err) + } + privBytes, err := x509.MarshalPKCS8PrivateKey(privkey) + if err != nil { + return fmt.Errorf("Unable to marshal private key: %v", err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + return fmt.Errorf("Failed to write data to %s: %v", keyname, err) + } + if err := keyOut.Close(); err != nil { + return fmt.Errorf("Error closing %s: %v", keyname, err) + } + return nil +} diff --git a/main.go b/main.go index 79e8e94..bf18f7e 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,8 @@ import ( "net/url" "os" "path/filepath" + + "git.yetaga.in/alazyreader/castor/certificates" ) var responseCodes = map[string]int{ @@ -37,35 +39,22 @@ var responseCodes = map[string]int{ "CERTIFICATENOTVALID": 62, } -// interface -type geminiRequest interface { - GetURL() *url.URL -} - -// implementation -type request struct { +type geminiRequest struct { url *url.URL } -func (r request) GetURL() *url.URL { +func (r *geminiRequest) GetURL() *url.URL { return r.url } -// interface -type geminiResponse interface { - WriteStatus(code int, meta string) (int, error) - Write([]byte) (int, error) -} - -// implementation -type response struct { +type geminiResponse struct { statusSent bool status int meta string connection net.Conn } -func (w *response) WriteStatus(code int, meta string) (int, error) { +func (w *geminiResponse) WriteStatus(code int, meta string) (int, error) { if w.statusSent { return 0, fmt.Errorf("Cannot set status after start of response") } @@ -75,7 +64,7 @@ func (w *response) WriteStatus(code int, meta string) (int, error) { return fmt.Fprintf(w.connection, "%d %s\r\n", code, meta) } -func (w *response) Write(b []byte) (int, error) { +func (w *geminiResponse) Write(b []byte) (int, error) { if !w.statusSent { // this can't guess text/gemini, of course. guessedType := http.DetectContentType(b) @@ -84,24 +73,23 @@ func (w *response) Write(b []byte) (int, error) { return w.connection.Write(b) } -// interface type geminiHandler interface { - Handle(geminiResponse, geminiRequest) + Handle(*geminiResponse, *geminiRequest) } -type geminiHandlerFunc func(geminiResponse, geminiRequest) +type geminiHandlerFunc func(*geminiResponse, *geminiRequest) // Handle calls f(w, r). -func (f geminiHandlerFunc) Handle(w geminiResponse, r geminiRequest) { +func (f geminiHandlerFunc) Handle(w *geminiResponse, r *geminiRequest) { f(w, r) } -// implementations +// handlers type staticGeminiHandler struct { StaticString string } -func (h staticGeminiHandler) Handle(w geminiResponse, r geminiRequest) { +func (h staticGeminiHandler) Handle(w *geminiResponse, r *geminiRequest) { w.Write([]byte(h.StaticString)) } @@ -124,7 +112,7 @@ func genIndex(folder, rel string) ([]byte, error) { return ret.Bytes(), nil } -func (h fsGeminiHandler) Handle(w geminiResponse, r geminiRequest) { +func (h fsGeminiHandler) Handle(w *geminiResponse, r *geminiRequest) { // Clean, then join; can't escape the defined root req := filepath.Join(h.root, filepath.Clean(r.GetURL().Path)) @@ -135,10 +123,10 @@ func (h fsGeminiHandler) Handle(w geminiResponse, r geminiRequest) { } if sourceFileStat.IsDir() { - sourceFileStat, err = os.Stat(filepath.Join(req, "index.gemini")) + sourceFileStat, err = os.Stat(filepath.Join(req, "index.gmi")) if err == nil && sourceFileStat.Mode().IsRegular() { - // if it's a directory, transparently insert the index.gemini check - req = filepath.Join(req, "index.gemini") + // if it's a directory, transparently insert the index.gmi check + req = filepath.Join(req, "index.gmi") } else if h.DirectoryListing { b, err := genIndex(req, filepath.Clean(r.GetURL().Path)) if err != nil { @@ -169,7 +157,7 @@ func (h fsGeminiHandler) Handle(w geminiResponse, r geminiRequest) { } func recoveryHandler(next geminiHandler) geminiHandler { - return geminiHandlerFunc(func(w geminiResponse, r geminiRequest) { + return geminiHandlerFunc(func(w *geminiResponse, r *geminiRequest) { defer func() { err := recover() if err != nil { @@ -183,19 +171,6 @@ func recoveryHandler(next geminiHandler) geminiHandler { }) } -// handler for general http queries (fallthrough for certmagic) -type genericHTTPHandler struct { - StaticString string -} - -func (h *genericHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if h.StaticString != "" { - w.Write([]byte(h.StaticString)) - return - } - w.Write([]byte("This is the default http response for the castor server. Try connecting over the gemini protocol instead.\n")) -} - func handleConnection(log Logger, conn net.Conn, h geminiHandler) { defer conn.Close() scanner := bufio.NewScanner(conn) @@ -206,27 +181,40 @@ func handleConnection(log Logger, conn net.Conn, h geminiHandler) { if err != nil { log.Info(err) } - w := response{ + w := &geminiResponse{ connection: conn, } - r := request{ + r := &geminiRequest{ url: u, } - recoveryHandler(h).Handle(&w, r) + recoveryHandler(h).Handle(w, r) } func main() { log := NewLogger(true) - err := mime.AddExtensionType(".gemini", "text/gemini") - err2 := mime.AddExtensionType(".gmi", "text/gemini") - if err != nil || err2 != nil { - log.Info("Could not add text/gemini to mime-type database;", err) + mime.AddExtensionType(".gemini", "text/gemini") + mime.AddExtensionType(".gmi", "text/gemini") + + err := certificates.TestCertificateExists("./cert.pem", "./key.pem") + var cer tls.Certificate + if err != nil { + log.Info("Generating new certificate...") + key, cert, err := certificates.GenerateKeyPair("localhost") + if err != nil { + log.Info("error generating certificates", err) + return + } + err = certificates.WriteCertsToFile("./cert.pem", "./key.pem", cert, key) + if err != nil { + log.Info("error saving certificates", err) + return + } } - cer, err := tls.LoadX509KeyPair("./cert.pem", "./key.pem") + cer, err = tls.LoadX509KeyPair("./cert.pem", "./key.pem") if err != nil { - log.Info("", err) + log.Info("error loading certificates", err) return }