85 lines
1.6 KiB
Go
85 lines
1.6 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"math"
|
|
"time"
|
|
)
|
|
|
|
var (
	// ErrInvalidSession is returned by Sessions.Get and Sessions.Refresh when
	// no session is stored under the given key or the stored session has
	// expired.
	ErrInvalidSession = fmt.Errorf("session not found")
)
|
|
|
|
type Sess struct {
|
|
User User
|
|
expr time.Time
|
|
}
|
|
|
|
type Sessions struct {
|
|
sessions map[string]Sess
|
|
}
|
|
|
|
func (s *Sessions) Create(user User, expr time.Duration) (string, error) {
|
|
if s.sessions == nil {
|
|
s.sessions = map[string]Sess{}
|
|
}
|
|
key := randomStr(24)
|
|
s.sessions[key] = Sess{
|
|
User: user,
|
|
expr: time.Now().Add(expr),
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (s *Sessions) Get(key string) (User, error) {
|
|
if s.sessions == nil {
|
|
s.sessions = map[string]Sess{}
|
|
return User{}, ErrInvalidSession
|
|
}
|
|
sess, ok := s.sessions[key]
|
|
if !ok {
|
|
return User{}, ErrInvalidSession
|
|
}
|
|
if sess.expr.Before(time.Now()) {
|
|
delete(s.sessions, key)
|
|
return User{}, ErrInvalidSession
|
|
}
|
|
return sess.User, nil
|
|
}
|
|
|
|
func (s *Sessions) Refresh(key string, user User, expr time.Duration) error {
|
|
if s.sessions == nil {
|
|
s.sessions = map[string]Sess{}
|
|
return ErrInvalidSession
|
|
}
|
|
sess, ok := s.sessions[key]
|
|
if !ok {
|
|
return ErrInvalidSession
|
|
}
|
|
if sess.expr.Before(time.Now()) {
|
|
delete(s.sessions, key)
|
|
return ErrInvalidSession
|
|
}
|
|
s.sessions[key] = Sess{
|
|
User: user,
|
|
expr: time.Now().Add(expr),
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Sessions) Destroy(key string) {
|
|
if s.sessions == nil {
|
|
s.sessions = map[string]Sess{}
|
|
return
|
|
}
|
|
delete(s.sessions, key)
|
|
}
|
|
|
|
// randomStr returns a string of exactly l characters built from
// cryptographically secure random bytes, using the unpadded base64url
// alphabet (A-Z, a-z, 0-9, '-', '_'). It returns "" when l <= 0.
//
// It panics if reading from crypto/rand fails. Continuing with an unfilled
// (all-zero) buffer would produce predictable session keys.
//
// Adapted from https://stackoverflow.com/a/55860599.
func randomStr(l int) string {
	if l <= 0 {
		return ""
	}

	// Each base64 character carries 6 bits, so ceil(3l/4) bytes encode to at
	// least l characters. Multiplying by 3 and dividing by 4 is exact in
	// float64 for any practical l, unlike dividing by a truncated 4/3.
	buf := make([]byte, int(math.Ceil(float64(l)*3/4)))
	if _, err := rand.Read(buf); err != nil {
		panic("randomStr: reading crypto/rand: " + err.Error())
	}

	// Unpadded encoding of ceil(3l/4) bytes yields l or l+1 characters;
	// trim to exactly l.
	return base64.RawURLEncoding.EncodeToString(buf)[:l]
}
|