add Protected and Authorize middleware, fix cookie bug, gofmt

This commit is contained in:
Iris Lightshard 2022-05-16 00:15:09 -06:00
parent 483e59e2b2
commit 0e5a81f27b
Signed by: Iris Lightshard
GPG key ID: 3B7FBC22144E6398
9 changed files with 464 additions and 408 deletions

View file

@ -1,48 +1,48 @@
package auth package auth
import ( import (
"time" "net/http"
"net/http" "nilfm.cc/git/quartzgun/cookie"
"nilfm.cc/git/quartzgun/cookie" "time"
) )
type User struct { type User struct {
Name string Name string
Pass string Pass string
Session string Session string
LoginTime time.Time LoginTime time.Time
LastSeen time.Time LastSeen time.Time
Data map[string]interface{} Data map[string]interface{}
} }
type UserStore interface { type UserStore interface {
InitiateSession(user string, password string) (string, error) InitiateSession(user string, password string) (string, error)
ValidateUser(user string, sessionId string) (bool, error) ValidateUser(user string, sessionId string) (bool, error)
EndSession(user string) error EndSession(user string) error
AddUser(user string, password string) error AddUser(user string, password string) error
DeleteUser(user string) error DeleteUser(user string) error
ChangePassword(user string, oldPassword string, newPassword string) error ChangePassword(user string, oldPassword string, newPassword string) error
SetData(user string, key string, value interface{}) error SetData(user string, key string, value interface{}) error
GetData(user string, key string) (interface{}, error) GetData(user string, key string) (interface{}, error)
} }
func Login(user string, password string, userStore UserStore, w http.ResponseWriter, t int) error { func Login(user string, password string, userStore UserStore, w http.ResponseWriter, t int) error {
session, loginErr := userStore.InitiateSession(user, password) session, loginErr := userStore.InitiateSession(user, password)
if loginErr == nil { if loginErr == nil {
cookie.StoreToken("user", user, w, t) cookie.StoreToken("user", user, w, t)
cookie.StoreToken("session", session, w, t) cookie.StoreToken("session", session, w, t)
return nil return nil
} }
return loginErr return loginErr
} }
func Logout(user string, userStore UserStore, w http.ResponseWriter) error { func Logout(user string, userStore UserStore, w http.ResponseWriter) error {
logoutErr := userStore.EndSession(user) logoutErr := userStore.EndSession(user)
if logoutErr == nil { if logoutErr == nil {
cookie.StoreToken("user", "", w, 0) cookie.StoreToken("user", "", w, 0)
cookie.StoreToken("session", "", w, 0) cookie.StoreToken("session", "", w, 0)
return nil return nil
} }
return logoutErr return logoutErr
} }

View file

@ -1,38 +1,38 @@
package cookie package cookie
import ( import (
"net/http" "crypto/rand"
"crypto/rand" "net/http"
"time" "time"
) )
var availableChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ@!.#$_" var availableChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ@!.#$_"
func GenToken(length int) string { func GenToken(length int) string {
ll := len(availableChars) ll := len(availableChars)
b := make([]byte, length) b := make([]byte, length)
rand.Read(b) rand.Read(b)
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
b[i] = availableChars[int(b[i])%ll] b[i] = availableChars[int(b[i])%ll]
} }
return string(b) return string(b)
} }
func StoreToken(field string, token string, w http.ResponseWriter, hrs int) { func StoreToken(field string, token string, w http.ResponseWriter, hrs int) {
cookie := http.Cookie{ cookie := http.Cookie{
Name: field, Name: field,
Value: token, Value: token,
Expires: time.Now().Add(time.Duration(hrs) * time.Hour), Expires: time.Now().Add(time.Duration(hrs) * time.Hour),
} }
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
} }
func GetToken(field string, req *http.Request) (string, error) { func GetToken(field string, req *http.Request) (string, error) {
c, err := req.Cookie(field) c, err := req.Cookie(field)
if err != nil { if err == nil {
return c.Value, nil return c.Value, nil
} else { } else {
return "", err return "", err
} }
} }

6
go.mod
View file

@ -2,8 +2,4 @@ module nilfm.cc/git/quartzgun
go 1.17 go 1.17
require ( require golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3
)

7
go.sum
View file

@ -1,2 +1,9 @@
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View file

@ -1,239 +1,239 @@
package indentalUserDB package indentalUserDB
import ( import (
"time" "errors"
"nilfm.cc/git/quartzgun/cookie" "fmt"
"nilfm.cc/git/quartzgun/auth" "golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/bcrypt" "nilfm.cc/git/quartzgun/auth"
"os" "nilfm.cc/git/quartzgun/cookie"
"strings" "os"
"fmt" "strings"
"errors" "time"
) )
type IndentalUserDB struct { type IndentalUserDB struct {
Users map[string]*auth.User Users map[string]*auth.User
Basis string Basis string
} }
func CreateIndentalUserDB(filePath string) *IndentalUserDB { func CreateIndentalUserDB(filePath string) *IndentalUserDB {
u, err := readDB(filePath) u, err := readDB(filePath)
if err == nil { if err == nil {
uMap := map[string]*auth.User{} uMap := map[string]*auth.User{}
for _, usr := range u { for _, usr := range u {
uMap[usr.Name] = usr uMap[usr.Name] = usr
} }
return &IndentalUserDB{ return &IndentalUserDB{
Users: uMap, Users: uMap,
Basis: filePath, Basis: filePath,
} }
} else { } else {
return &IndentalUserDB{ return &IndentalUserDB{
Users: map[string]*auth.User{}, Users: map[string]*auth.User{},
Basis: filePath, Basis: filePath,
} }
} }
} }
func (self *IndentalUserDB) InitiateSession(user string, password string) (string, error) { func (self *IndentalUserDB) InitiateSession(user string, password string) (string, error) {
if _, exists := self.Users[user]; !exists { if _, exists := self.Users[user]; !exists {
return "", errors.New("User not in DB") return "", errors.New("User not in DB")
} }
if bcrypt.CompareHashAndPassword([]byte(self.Users[user].Pass), []byte(password)) != nil { if bcrypt.CompareHashAndPassword([]byte(self.Users[user].Pass), []byte(password)) != nil {
return "", errors.New("Incorrect password") return "", errors.New("Incorrect password")
} }
sessionId := cookie.GenToken(64) sessionId := cookie.GenToken(64)
self.Users[user].Session = sessionId self.Users[user].Session = sessionId
self.Users[user].LoginTime = time.Now() self.Users[user].LoginTime = time.Now()
self.Users[user].LastSeen = time.Now() self.Users[user].LastSeen = time.Now()
writeDB(self.Basis, self.Users) writeDB(self.Basis, self.Users)
return sessionId, nil return sessionId, nil
} }
func (self *IndentalUserDB) ValidateUser(user string, sessionId string) (bool, error) { func (self *IndentalUserDB) ValidateUser(user string, sessionId string) (bool, error) {
if _, exists := self.Users[user]; !exists { if _, exists := self.Users[user]; !exists {
return false, errors.New("User not in DB") return false, errors.New("User not in DB")
} }
validated := self.Users[user].Session == sessionId validated := self.Users[user].Session == sessionId
if validated { if validated {
self.Users[user].LastSeen = time.Now() self.Users[user].LastSeen = time.Now()
writeDB(self.Basis, self.Users) writeDB(self.Basis, self.Users)
} }
return validated, nil return validated, nil
} }
func (self *IndentalUserDB) EndSession(user string) error { func (self *IndentalUserDB) EndSession(user string) error {
if _, exists := self.Users[user]; !exists { if _, exists := self.Users[user]; !exists {
return errors.New("User not in DB") return errors.New("User not in DB")
} }
self.Users[user].Session = "" self.Users[user].Session = ""
self.Users[user].LastSeen = time.Now() self.Users[user].LastSeen = time.Now()
writeDB(self.Basis, self.Users) writeDB(self.Basis, self.Users)
return nil return nil
} }
func (self *IndentalUserDB) DeleteUser(user string) error { func (self *IndentalUserDB) DeleteUser(user string) error {
if _, exists := self.Users[user]; !exists { if _, exists := self.Users[user]; !exists {
return errors.New("User not in DB") return errors.New("User not in DB")
} }
delete(self.Users, user) delete(self.Users, user)
writeDB(self.Basis, self.Users) writeDB(self.Basis, self.Users)
return nil return nil
} }
func (self *IndentalUserDB) ChangePassword(user string, password string, oldPassword string) error { func (self *IndentalUserDB) ChangePassword(user string, password string, oldPassword string) error {
if _, exists := self.Users[user]; !exists { if _, exists := self.Users[user]; !exists {
return errors.New("User not in DB") return errors.New("User not in DB")
} }
if bcrypt.CompareHashAndPassword([]byte(self.Users[user].Pass), []byte(oldPassword)) != nil { if bcrypt.CompareHashAndPassword([]byte(self.Users[user].Pass), []byte(oldPassword)) != nil {
return errors.New("Incorrect password") return errors.New("Incorrect password")
} }
hash, _ := bcrypt.GenerateFromPassword([]byte(password), 10) hash, _ := bcrypt.GenerateFromPassword([]byte(password), 10)
self.Users[user].Pass = string(hash[:]) self.Users[user].Pass = string(hash[:])
writeDB(self.Basis, self.Users) writeDB(self.Basis, self.Users)
return nil return nil
} }
func (self *IndentalUserDB) AddUser(user string, password string) error{ func (self *IndentalUserDB) AddUser(user string, password string) error {
if _, exists := self.Users[user]; exists { if _, exists := self.Users[user]; exists {
return errors.New("User already in DB") return errors.New("User already in DB")
} }
hash, _ := bcrypt.GenerateFromPassword([]byte(password), 10) hash, _ := bcrypt.GenerateFromPassword([]byte(password), 10)
self.Users[user] = &auth.User{ self.Users[user] = &auth.User{
Name: user, Name: user,
Pass: string(hash[:]), Pass: string(hash[:]),
LastSeen: time.UnixMicro(0), LastSeen: time.UnixMicro(0),
LoginTime: time.UnixMicro(0), LoginTime: time.UnixMicro(0),
Session: "", Session: "",
} }
writeDB(self.Basis, self.Users) writeDB(self.Basis, self.Users)
return nil; return nil
} }
func (self *IndentalUserDB) SetData(user string, key string, value interface{}) error { func (self *IndentalUserDB) SetData(user string, key string, value interface{}) error {
if _, exists := self.Users[user]; !exists { if _, exists := self.Users[user]; !exists {
return errors.New("User not in DB") return errors.New("User not in DB")
} }
self.Users[user].Data[key] = value; self.Users[user].Data[key] = value
return nil; return nil
} }
func (self *IndentalUserDB) GetData(user string, key string) (interface{}, error) { func (self *IndentalUserDB) GetData(user string, key string) (interface{}, error) {
if _, usrExists := self.Users[user]; !usrExists { if _, usrExists := self.Users[user]; !usrExists {
return nil, errors.New("User not in DB") return nil, errors.New("User not in DB")
} }
data, exists := self.Users[user].Data[key] data, exists := self.Users[user].Data[key]
if !exists { if !exists {
return nil, errors.New("No data key for user") return nil, errors.New("No data key for user")
} }
return data, nil return data, nil
} }
const timeFmt = "2006-01-02T15:04Z" const timeFmt = "2006-01-02T15:04Z"
func readDB(filePath string) (map[string]*auth.User, error) { func readDB(filePath string) (map[string]*auth.User, error) {
f, err := os.ReadFile(filePath) f, err := os.ReadFile(filePath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fileData := string(f[:]) fileData := string(f[:])
users := map[string]*auth.User{} users := map[string]*auth.User{}
lines := strings.Split(fileData, "\n") lines := strings.Split(fileData, "\n")
indentLevel := "" indentLevel := ""
var name string var name string
var pass string var pass string
var session string var session string
var loginTime time.Time var loginTime time.Time
var lastSeen time.Time var lastSeen time.Time
var data map[string]interface{} var data map[string]interface{}
for _, l := range lines { for _, l := range lines {
if strings.HasPrefix(l, indentLevel) { if strings.HasPrefix(l, indentLevel) {
switch indentLevel { switch indentLevel {
case "": case "":
name = l name = l
indentLevel = "\t" indentLevel = "\t"
case "\t": case "\t":
if strings.Contains(l, ":") { if strings.Contains(l, ":") {
kvp := strings.Split(l, ":") kvp := strings.Split(l, ":")
k := strings.TrimSpace(kvp[0]) k := strings.TrimSpace(kvp[0])
v := strings.TrimSpace(kvp[1]) v := strings.TrimSpace(kvp[1])
switch k { switch k {
case "pass": case "pass":
pass = v pass = v
case "session": case "session":
session = v session = v
case "loginTime": case "loginTime":
loginTime, _ = time.Parse(timeFmt, v) loginTime, _ = time.Parse(timeFmt, v)
case "lastSeen": case "lastSeen":
lastSeen, _ = time.Parse(timeFmt, v) lastSeen, _ = time.Parse(timeFmt, v)
} }
} else { } else {
data = map[string]interface{}{} data = map[string]interface{}{}
indentLevel = "\t\t" indentLevel = "\t\t"
} }
case "\t\t": case "\t\t":
if strings.Contains(l, ":") { if strings.Contains(l, ":") {
kvp := strings.Split(l, ":") kvp := strings.Split(l, ":")
k := strings.TrimSpace(kvp[0]) k := strings.TrimSpace(kvp[0])
v := strings.TrimSpace(kvp[1]) v := strings.TrimSpace(kvp[1])
data[k] = v data[k] = v
} }
} }
} else { } else {
if indentLevel != "\t\t" { if indentLevel != "\t\t" {
panic("Malformed indental file") panic("Malformed indental file")
} else { } else {
users[name] = &auth.User{ users[name] = &auth.User{
Name: name, Name: name,
Pass: pass, Pass: pass,
Session: session, Session: session,
LoginTime: loginTime, LoginTime: loginTime,
LastSeen: lastSeen, LastSeen: lastSeen,
Data: data, Data: data,
} }
indentLevel = "" indentLevel = ""
} }
} }
} }
return users, nil return users, nil
} }
func writeDB(filePath string, users map[string]*auth.User) error { func writeDB(filePath string, users map[string]*auth.User) error {
f, err := os.Create(filePath) f, err := os.Create(filePath)
if err != nil { if err != nil {
return err return err
} }
defer f.Close() defer f.Close()
for _, user := range users { for _, user := range users {
f.WriteString(fmt.Sprintf("%s\n\tpass: %s\n\tsession: %s\n\tloginTime: %s\n\tlastSeen: %s\n\tdata\n", f.WriteString(fmt.Sprintf("%s\n\tpass: %s\n\tsession: %s\n\tloginTime: %s\n\tlastSeen: %s\n\tdata\n",
user.Name, user.Name,
user.Pass, user.Pass,
user.Session, user.Session,
user.LoginTime, user.LoginTime,
user.LastSeen)); user.LastSeen))
for k, v := range user.Data { for k, v := range user.Data {
f.WriteString(fmt.Sprintf("\t\t%s: %s\n", k, v)) f.WriteString(fmt.Sprintf("\t\t%s: %s\n", k, v))
} }
f.WriteString("\n") f.WriteString("\n")
} }
f.Sync() f.Sync()
return nil return nil
} }

53
middleware/middleware.go Normal file
View file

@ -0,0 +1,53 @@
package middleware
import (
"context"
"net/http"
"nilfm.cc/git/quartzgun/auth"
"nilfm.cc/git/quartzgun/cookie"
)
func Protected(next http.Handler, userStore auth.UserStore) http.Handler {
handlerFunc := func(w http.ResponseWriter, req *http.Request) {
user, err := cookie.GetToken("user", req)
if err == nil {
session, err := cookie.GetToken("session", req)
if err == nil {
login, err := userStore.ValidateUser(user, session)
if err == nil && login {
next.ServeHTTP(w, req)
return
}
}
}
req.Method = http.MethodGet
http.Redirect(w, req, "/login", http.StatusTemporaryRedirect)
}
return http.HandlerFunc(handlerFunc)
}
func Authorize(next string, userStore auth.UserStore) http.Handler {
handlerFunc := func(w http.ResponseWriter, req *http.Request) {
err := auth.Login(
req.FormValue("user"),
req.FormValue("password"),
userStore,
w,
24*7*52)
if err == nil {
req.Method = http.MethodGet
http.Redirect(w, req, next, http.StatusOK)
} else {
*req = *req.WithContext(
context.WithValue(
req.Context(),
"message",
"Incorrect credentials"))
req.Method = http.MethodGet
http.Redirect(w, req, "/login", http.StatusTemporaryRedirect)
}
}
return http.HandlerFunc(handlerFunc)
}

View file

@ -1,53 +1,53 @@
package main package main
import ( import (
"fmt" "context"
"fmt"
"html/template"
"net/http" "net/http"
"html/template" "nilfm.cc/git/quartzgun/indentalUserDB"
"context" "nilfm.cc/git/quartzgun/renderer"
"nilfm.cc/git/quartzgun/router" "nilfm.cc/git/quartzgun/router"
"nilfm.cc/git/quartzgun/renderer" "testing"
"nilfm.cc/git/quartzgun/indentalUserDB"
"testing"
) )
func AddContent(next http.Handler) http.Handler { func AddContent(next http.Handler) http.Handler {
handlerFunc := func(w http.ResponseWriter, req *http.Request) { handlerFunc := func(w http.ResponseWriter, req *http.Request) {
if !req.Form.Has("Content") { if !req.Form.Has("Content") {
req.Form.Add("Content", "Yesssssss") req.Form.Add("Content", "Yesssssss")
} }
next.ServeHTTP(w, req) next.ServeHTTP(w, req)
} }
return http.HandlerFunc(handlerFunc) return http.HandlerFunc(handlerFunc)
} }
func ApiSomething(next http.Handler) http.Handler { func ApiSomething(next http.Handler) http.Handler {
handlerFunc := func(w http.ResponseWriter, req *http.Request) { handlerFunc := func(w http.ResponseWriter, req *http.Request) {
*req = *req.WithContext(context.WithValue(req.Context(), "apiData", "something")) *req = *req.WithContext(context.WithValue(req.Context(), "apiData", "something"))
next.ServeHTTP(w, req) next.ServeHTTP(w, req)
} }
return http.HandlerFunc(handlerFunc) return http.HandlerFunc(handlerFunc)
} }
func TestMain(m *testing.M){ func TestMain(m *testing.M) {
udb := indentalUserDB.CreateIndentalUserDB("testData/userDB.ndtl") udb := indentalUserDB.CreateIndentalUserDB("testData/userDB.ndtl")
udb.AddUser("nilix", "questing") udb.AddUser("nilix", "questing")
sesh, _ := udb.InitiateSession("nilix", "questing") sesh, _ := udb.InitiateSession("nilix", "questing")
fmt.Printf("%s // %s\n", sesh, sesh) fmt.Printf("%s // %s\n", sesh, sesh)
rtr := &router.Router{ rtr := &router.Router{
StaticPaths: map[string]string{ StaticPaths: map[string]string{
"/static": "testData/static", "/static": "testData/static",
}, },
Fallback: *template.Must(template.ParseFiles("testData/templates/error.html", "testData/templates/footer.html")), Fallback: *template.Must(template.ParseFiles("testData/templates/error.html", "testData/templates/footer.html")),
} }
rtr.Get("/", AddContent(renderer.Template("testData/templates/test.html"))) rtr.Get("/", AddContent(renderer.Template("testData/templates/test.html")))
rtr.Get("/json", ApiSomething(renderer.JSON("apiData"))) rtr.Get("/json", ApiSomething(renderer.JSON("apiData")))
rtr.Get(`/thing/(?P<Thing>\w+)`, renderer.Template("testData/templates/paramTest.html")) rtr.Get(`/thing/(?P<Thing>\w+)`, renderer.Template("testData/templates/paramTest.html"))
http.ListenAndServe(":8080", rtr) http.ListenAndServe(":8080", rtr)
} }

View file

@ -1,48 +1,48 @@
package renderer package renderer
import ( import (
"net/http" "encoding/json"
"html/template" "encoding/xml"
"encoding/json" "html/template"
"encoding/xml" "net/http"
) )
func Template(t ...string) http.Handler { func Template(t ...string) http.Handler {
tmpl := template.Must(template.ParseFiles(t...)) tmpl := template.Must(template.ParseFiles(t...))
handlerFunc := func(w http.ResponseWriter, req *http.Request) { handlerFunc := func(w http.ResponseWriter, req *http.Request) {
tmpl.Execute(w, req) tmpl.Execute(w, req)
} }
return http.HandlerFunc(handlerFunc) return http.HandlerFunc(handlerFunc)
} }
func JSON(key string) http.Handler { func JSON(key string) http.Handler {
handlerFunc := func(w http.ResponseWriter, req *http.Request) { handlerFunc := func(w http.ResponseWriter, req *http.Request) {
apiData := req.Context().Value(key) apiData := req.Context().Value(key)
data, err := json.Marshal(apiData) data, err := json.Marshal(apiData)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write(data) w.Write(data)
} }
return http.HandlerFunc(handlerFunc) return http.HandlerFunc(handlerFunc)
} }
func XML(key string) http.Handler { func XML(key string) http.Handler {
handlerFunc := func(w http.ResponseWriter, req *http.Request) { handlerFunc := func(w http.ResponseWriter, req *http.Request) {
apiData := req.Context().Value(key) apiData := req.Context().Value(key)
data, err := xml.MarshalIndent(apiData, "", " ") data, err := xml.MarshalIndent(apiData, "", " ")
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
} }
w.Header().Set("Content-Type", "application/xml") w.Header().Set("Content-Type", "application/xml")
w.Write(data) w.Write(data)
} }
return http.HandlerFunc(handlerFunc) return http.HandlerFunc(handlerFunc)
} }

View file

@ -1,127 +1,127 @@
package router package router
import ( import (
"net/http" "context"
"html/template" "errors"
"regexp" "html/template"
"log" "log"
"strconv" "net/http"
"strings" "os"
"path" "path"
"os" "regexp"
"errors" "strconv"
"context" "strings"
) )
type Router struct { type Router struct {
/* This is the template for error pages */ /* This is the template for error pages */
Fallback template.Template Fallback template.Template
/* Routes are only filled by using the appropriate methods. */ /* Routes are only filled by using the appropriate methods. */
routes []Route routes []Route
/* StaticPaths can be filled from outside when constructing the Router. /* StaticPaths can be filled from outside when constructing the Router.
* key = uri * key = uri
* value = file path * value = file path
*/ */
StaticPaths map[string]string StaticPaths map[string]string
} }
type Route struct { type Route struct {
path *regexp.Regexp path *regexp.Regexp
handlerMap map[string]http.Handler handlerMap map[string]http.Handler
} }
func (self *Router) Get(path string, h http.Handler) { func (self *Router) Get(path string, h http.Handler) {
self.AddRoute("GET", path, h) self.AddRoute("GET", path, h)
} }
func (self *Router) Post(path string, h http.Handler) { func (self *Router) Post(path string, h http.Handler) {
self.AddRoute("POST", path, h) self.AddRoute("POST", path, h)
} }
func (self *Router) Put(path string, h http.Handler) { func (self *Router) Put(path string, h http.Handler) {
self.AddRoute("PUT", path, h) self.AddRoute("PUT", path, h)
} }
func (self *Router) Delete(path string, h http.Handler) { func (self *Router) Delete(path string, h http.Handler) {
self.AddRoute("DELETE", path, h) self.AddRoute("DELETE", path, h)
} }
func (self *Router) AddRoute(method string, path string, h http.Handler) { func (self *Router) AddRoute(method string, path string, h http.Handler) {
exactPath := regexp.MustCompile("^" + path + "$") exactPath := regexp.MustCompile("^" + path + "$")
/* If the route already exists, try to add this method to the ServerTask map. */ /* If the route already exists, try to add this method to the ServerTask map. */
for _, r := range self.routes { for _, r := range self.routes {
if r.path == exactPath { if r.path == exactPath {
r.handlerMap[method] = h r.handlerMap[method] = h
return return
} }
} }
/* Otherwise add a new route */ /* Otherwise add a new route */
self.routes = append(self.routes, Route{ self.routes = append(self.routes, Route{
path: exactPath, path: exactPath,
handlerMap: map[string]http.Handler{method: h}, handlerMap: map[string]http.Handler{method: h},
}) })
} }
func (self *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (self *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
/* Show the 500 error page if we panic */ /* Show the 500 error page if we panic */
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Println("ERROR:", r) log.Println("ERROR:", r)
self.ErrorPage(w, req, 500, "There was an error on the server.") self.ErrorPage(w, req, 500, "There was an error on the server.")
} }
}() }()
/* If the request matches any our StaticPaths, try to serve a file. */ /* If the request matches any our StaticPaths, try to serve a file. */
for uri, dir := range self.StaticPaths { for uri, dir := range self.StaticPaths {
if req.Method == "GET" && strings.HasPrefix(req.URL.Path, uri) { if req.Method == "GET" && strings.HasPrefix(req.URL.Path, uri) {
restOfUri := strings.TrimPrefix(req.URL.Path, uri) restOfUri := strings.TrimPrefix(req.URL.Path, uri)
p := path.Join(dir, restOfUri) p := path.Join(dir, restOfUri)
p = path.Clean(p) p = path.Clean(p)
/* If the file exists, try to serve it. */ /* If the file exists, try to serve it. */
info, err := os.Stat(p); info, err := os.Stat(p)
if err == nil && !info.IsDir() { if err == nil && !info.IsDir() {
http.ServeFile(w, req, p) http.ServeFile(w, req, p)
/* Handle the common errors */ /* Handle the common errors */
} else if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrExist) { } else if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrExist) {
self.ErrorPage(w, req, 404, "The requested file does not exist") self.ErrorPage(w, req, 404, "The requested file does not exist")
} else if errors.Is(err, os.ErrPermission) || info.IsDir() { } else if errors.Is(err, os.ErrPermission) || info.IsDir() {
self.ErrorPage(w, req, 403, "Access forbidden") self.ErrorPage(w, req, 403, "Access forbidden")
/* If it's some weird error, serve a 500. */ /* If it's some weird error, serve a 500. */
} else { } else {
self.ErrorPage(w, req, 500, "Internal server error") self.ErrorPage(w, req, 500, "Internal server error")
} }
return return
} }
} }
/* Otherwise, this is a normal route */ /* Otherwise, this is a normal route */
for _, r := range self.routes { for _, r := range self.routes {
/* Pull the params out of the regex; /* Pull the params out of the regex;
* If the path doesn't match the regex, params will be nil. * If the path doesn't match the regex, params will be nil.
*/ */
params := r.Match(req) params := r.Match(req)
if params == nil { if params == nil {
continue continue
} }
for method, handler := range r.handlerMap { for method, handler := range r.handlerMap {
if method == req.Method { if method == req.Method {
/* Parse the form and add the params to the context */ /* Parse the form and add the params to the context */
req.ParseForm() req.ParseForm()
ProcessParams(req, params) ProcessParams(req, params)
/* handle the request! */ /* handle the request! */
handler.ServeHTTP(w, req); handler.ServeHTTP(w, req)
return return
} }
} }
} }
self.ErrorPage(w, req, 404, "The page you requested does not exist!") self.ErrorPage(w, req, 404, "The page you requested does not exist!")
} }
/******************* /*******************
@ -129,31 +129,31 @@ func (self *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
*******************/ *******************/
func ProcessParams(req *http.Request, params map[string]string) { func ProcessParams(req *http.Request, params map[string]string) {
*req = *req.WithContext(context.WithValue(req.Context(), "params", params)) *req = *req.WithContext(context.WithValue(req.Context(), "params", params))
} }
func (self *Route) Match(r *http.Request) map[string]string { func (self *Route) Match(r *http.Request) map[string]string {
match := self.path.FindStringSubmatch(r.URL.Path) match := self.path.FindStringSubmatch(r.URL.Path)
if match == nil { if match == nil {
return nil return nil
} }
params := map[string]string{} params := map[string]string{}
groupNames := self.path.SubexpNames() groupNames := self.path.SubexpNames()
for i, group := range match { for i, group := range match {
params[groupNames[i]] = group params[groupNames[i]] = group
} }
return params return params
} }
func (self *Router) ErrorPage(w http.ResponseWriter, req *http.Request, code int, errMsg string) { func (self *Router) ErrorPage(w http.ResponseWriter, req *http.Request, code int, errMsg string) {
w.WriteHeader(code) w.WriteHeader(code)
params := map[string]string{ params := map[string]string{
"ErrorCode": strconv.Itoa(code), "ErrorCode": strconv.Itoa(code),
"ErrorMessage": errMsg, "ErrorMessage": errMsg,
} }
ProcessParams(req, params) ProcessParams(req, params)
self.Fallback.Execute(w, req) self.Fallback.Execute(w, req)
} }