diff --git a/auth/auth.go b/auth/auth.go index 00e8e69..ccb5573 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,48 +1,48 @@ package auth import ( - "time" - "net/http" - "nilfm.cc/git/quartzgun/cookie" + "net/http" + "nilfm.cc/git/quartzgun/cookie" + "time" ) type User struct { - Name string - Pass string - Session string - LoginTime time.Time - LastSeen time.Time + Name string + Pass string + Session string + LoginTime time.Time + LastSeen time.Time - Data map[string]interface{} + Data map[string]interface{} } type UserStore interface { - InitiateSession(user string, password string) (string, error) - ValidateUser(user string, sessionId string) (bool, error) - EndSession(user string) error - AddUser(user string, password string) error - DeleteUser(user string) error - ChangePassword(user string, oldPassword string, newPassword string) error - SetData(user string, key string, value interface{}) error - GetData(user string, key string) (interface{}, error) + InitiateSession(user string, password string) (string, error) + ValidateUser(user string, sessionId string) (bool, error) + EndSession(user string) error + AddUser(user string, password string) error + DeleteUser(user string) error + ChangePassword(user string, oldPassword string, newPassword string) error + SetData(user string, key string, value interface{}) error + GetData(user string, key string) (interface{}, error) } func Login(user string, password string, userStore UserStore, w http.ResponseWriter, t int) error { - session, loginErr := userStore.InitiateSession(user, password) - if loginErr == nil { - cookie.StoreToken("user", user, w, t) - cookie.StoreToken("session", session, w, t) - return nil - } - return loginErr + session, loginErr := userStore.InitiateSession(user, password) + if loginErr == nil { + cookie.StoreToken("user", user, w, t) + cookie.StoreToken("session", session, w, t) + return nil + } + return loginErr } func Logout(user string, userStore UserStore, w http.ResponseWriter) error { - logoutErr := userStore.EndSession(user) - if logoutErr == nil { - cookie.StoreToken("user", "", w, 0) - cookie.StoreToken("session", "", w, 0) - return nil - } - return logoutErr + logoutErr := userStore.EndSession(user) + if logoutErr == nil { + cookie.StoreToken("user", "", w, 0) + cookie.StoreToken("session", "", w, 0) + return nil + } + return logoutErr } diff --git a/cookie/cookie.go b/cookie/cookie.go index 935ae46..eda305c 100644 --- a/cookie/cookie.go +++ b/cookie/cookie.go @@ -1,38 +1,38 @@ package cookie import ( - "net/http" - "crypto/rand" - "time" + "crypto/rand" + "net/http" + "time" ) var availableChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ@!.#$_" func GenToken(length int) string { - ll := len(availableChars) - b := make([]byte, length) - rand.Read(b) - for i := 0; i < length; i++ { - b[i] = availableChars[int(b[i])%ll] - } - return string(b) + ll := len(availableChars) + b := make([]byte, length) + rand.Read(b) + for i := 0; i < length; i++ { + b[i] = availableChars[int(b[i])%ll] + } + return string(b) } func StoreToken(field string, token string, w http.ResponseWriter, hrs int) { - cookie := http.Cookie{ - Name: field, - Value: token, - Expires: time.Now().Add(time.Duration(hrs) * time.Hour), - } + cookie := http.Cookie{ + Name: field, + Value: token, + Expires: time.Now().Add(time.Duration(hrs) * time.Hour), + } - http.SetCookie(w, &cookie) + http.SetCookie(w, &cookie) } func GetToken(field string, req *http.Request) (string, error) { - c, err := req.Cookie(field) - if err != nil { - return c.Value, nil - } else { - return "", err - } + c, err := req.Cookie(field) + if err == nil { + return c.Value, nil + } else { + return "", err + } } diff --git a/go.mod b/go.mod index 801ed98..c264153 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,4 @@ module nilfm.cc/git/quartzgun go 1.17 -require ( - golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 -) - - +require golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 diff --git a/go.sum b/go.sum index ab658c0..d0b757e 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,9 @@ golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/indentalUserDB/indentalUserDB.go b/indentalUserDB/indentalUserDB.go index 3ab3587..f79d314 100644 --- a/indentalUserDB/indentalUserDB.go +++ b/indentalUserDB/indentalUserDB.go @@ -1,239 +1,239 @@ package indentalUserDB import ( - "time" - "nilfm.cc/git/quartzgun/cookie" - "nilfm.cc/git/quartzgun/auth" - "golang.org/x/crypto/bcrypt" - "os" - "strings" - "fmt" - "errors" + "errors" + "fmt" + "golang.org/x/crypto/bcrypt" + "nilfm.cc/git/quartzgun/auth" + "nilfm.cc/git/quartzgun/cookie" + "os" + "strings" + "time" ) type IndentalUserDB struct { - Users map[string]*auth.User - Basis string + Users map[string]*auth.User + Basis string } func CreateIndentalUserDB(filePath string) *IndentalUserDB { - u, err := readDB(filePath) - if err == nil { - uMap := map[string]*auth.User{} - for _, usr := range u { - uMap[usr.Name] = usr - } - return &IndentalUserDB{ - Users: uMap, - Basis: filePath, - } - } else { - return &IndentalUserDB{ - Users: map[string]*auth.User{}, - Basis: filePath, - } - } + u, err := readDB(filePath) + if err == nil { + uMap := map[string]*auth.User{} + for _, usr := range u { + uMap[usr.Name] = usr + } + return &IndentalUserDB{ + Users: uMap, + Basis: filePath, + } + } else { + return &IndentalUserDB{ + Users: map[string]*auth.User{}, + Basis: filePath, + } + } } func (self *IndentalUserDB) InitiateSession(user string, password string) (string, error) { - if _, exists := self.Users[user]; !exists { - return "", errors.New("User not in DB") - } - if bcrypt.CompareHashAndPassword([]byte(self.Users[user].Pass), []byte(password)) != nil { - return "", errors.New("Incorrect password") - } - sessionId := cookie.GenToken(64) - self.Users[user].Session = sessionId - self.Users[user].LoginTime = time.Now() - self.Users[user].LastSeen = time.Now() - writeDB(self.Basis, self.Users) - return sessionId, nil + if _, exists := self.Users[user]; !exists { + return "", errors.New("User not in DB") + } + if bcrypt.CompareHashAndPassword([]byte(self.Users[user].Pass), []byte(password)) != nil { + return "", errors.New("Incorrect password") + } + sessionId := cookie.GenToken(64) + self.Users[user].Session = sessionId + self.Users[user].LoginTime = time.Now() + self.Users[user].LastSeen = time.Now() + writeDB(self.Basis, self.Users) + return sessionId, nil } func (self *IndentalUserDB) ValidateUser(user string, sessionId string) (bool, error) { - if _, exists := self.Users[user]; !exists { - return false, errors.New("User not in DB") - } + if _, exists := self.Users[user]; !exists { + return false, errors.New("User not in DB") + } - validated := self.Users[user].Session == sessionId - if validated { - self.Users[user].LastSeen = time.Now() - writeDB(self.Basis, self.Users) - } + validated := self.Users[user].Session == sessionId + if validated { + self.Users[user].LastSeen = time.Now() + writeDB(self.Basis, self.Users) + } - return validated, nil + return validated, nil } func (self *IndentalUserDB) EndSession(user string) error { - if _, exists := self.Users[user]; !exists { - return errors.New("User not in DB") - } + if _, exists := self.Users[user]; !exists { + return errors.New("User not in DB") + } - self.Users[user].Session = "" - self.Users[user].LastSeen = time.Now() - writeDB(self.Basis, self.Users) - return nil + self.Users[user].Session = "" + self.Users[user].LastSeen = time.Now() + writeDB(self.Basis, self.Users) + return nil } func (self *IndentalUserDB) DeleteUser(user string) error { - if _, exists := self.Users[user]; !exists { - return errors.New("User not in DB") - } + if _, exists := self.Users[user]; !exists { + return errors.New("User not in DB") + } - delete(self.Users, user) - writeDB(self.Basis, self.Users) - return nil + delete(self.Users, user) + writeDB(self.Basis, self.Users) + return nil } func (self *IndentalUserDB) ChangePassword(user string, password string, oldPassword string) error { - if _, exists := self.Users[user]; !exists { - return errors.New("User not in DB") - } - if bcrypt.CompareHashAndPassword([]byte(self.Users[user].Pass), []byte(oldPassword)) != nil { - return errors.New("Incorrect password") - } + if _, exists := self.Users[user]; !exists { + return errors.New("User not in DB") + } + if bcrypt.CompareHashAndPassword([]byte(self.Users[user].Pass), []byte(oldPassword)) != nil { + return errors.New("Incorrect password") + } - hash, _ := bcrypt.GenerateFromPassword([]byte(password), 10) - self.Users[user].Pass = string(hash[:]) - writeDB(self.Basis, self.Users) - return nil + hash, _ := bcrypt.GenerateFromPassword([]byte(password), 10) + self.Users[user].Pass = string(hash[:]) + writeDB(self.Basis, self.Users) + return nil } -func (self *IndentalUserDB) AddUser(user string, password string) error{ - if _, exists := self.Users[user]; exists { - return errors.New("User already in DB") - } +func (self *IndentalUserDB) AddUser(user string, password string) error { + if _, exists := self.Users[user]; exists { + return errors.New("User already in DB") + } - hash, _ := bcrypt.GenerateFromPassword([]byte(password), 10) + hash, _ := bcrypt.GenerateFromPassword([]byte(password), 10) - self.Users[user] = &auth.User{ - Name: user, - Pass: string(hash[:]), - LastSeen: time.UnixMicro(0), - LoginTime: time.UnixMicro(0), - Session: "", - } - writeDB(self.Basis, self.Users) - return nil; + self.Users[user] = &auth.User{ + Name: user, + Pass: string(hash[:]), + LastSeen: time.UnixMicro(0), + LoginTime: time.UnixMicro(0), + Session: "", + } + writeDB(self.Basis, self.Users) + return nil } func (self *IndentalUserDB) SetData(user string, key string, value interface{}) error { - if _, exists := self.Users[user]; !exists { - return errors.New("User not in DB") - } + if _, exists := self.Users[user]; !exists { + return errors.New("User not in DB") + } - self.Users[user].Data[key] = value; - return nil; + self.Users[user].Data[key] = value + return nil } func (self *IndentalUserDB) GetData(user string, key string) (interface{}, error) { - if _, usrExists := self.Users[user]; !usrExists { - return nil, errors.New("User not in DB") - } - data, exists := self.Users[user].Data[key] - if !exists { - return nil, errors.New("No data key for user") - } + if _, usrExists := self.Users[user]; !usrExists { + return nil, errors.New("User not in DB") + } + data, exists := self.Users[user].Data[key] + if !exists { + return nil, errors.New("No data key for user") + } - return data, nil + return data, nil } const timeFmt = "2006-01-02T15:04Z" func readDB(filePath string) (map[string]*auth.User, error) { - f, err := os.ReadFile(filePath) - if err != nil { - return nil, err - } + f, err := os.ReadFile(filePath) + if err != nil { + return nil, err + } - fileData := string(f[:]) - users := map[string]*auth.User{} + fileData := string(f[:]) + users := map[string]*auth.User{} - lines := strings.Split(fileData, "\n") + lines := strings.Split(fileData, "\n") - indentLevel := "" + indentLevel := "" - var name string - var pass string - var session string - var loginTime time.Time - var lastSeen time.Time - var data map[string]interface{} + var name string + var pass string + var session string + var loginTime time.Time + var lastSeen time.Time + var data map[string]interface{} - for _, l := range lines { - if strings.HasPrefix(l, indentLevel) { - switch indentLevel { - case "": - name = l - indentLevel = "\t" + for _, l := range lines { + if strings.HasPrefix(l, indentLevel) { + switch indentLevel { + case "": + name = l + indentLevel = "\t" - case "\t": - if strings.Contains(l, ":") { - kvp := strings.Split(l, ":") - k := strings.TrimSpace(kvp[0]) - v := strings.TrimSpace(kvp[1]) - switch k { - case "pass": - pass = v - case "session": - session = v - case "loginTime": - loginTime, _ = time.Parse(timeFmt, v) - case "lastSeen": - lastSeen, _ = time.Parse(timeFmt, v) - } - } else { - data = map[string]interface{}{} - indentLevel = "\t\t" - } + case "\t": + if strings.Contains(l, ":") { + kvp := strings.Split(l, ":") + k := strings.TrimSpace(kvp[0]) + v := strings.TrimSpace(kvp[1]) + switch k { + case "pass": + pass = v + case "session": + session = v + case "loginTime": + loginTime, _ = time.Parse(timeFmt, v) + case "lastSeen": + lastSeen, _ = time.Parse(timeFmt, v) + } + } else { + data = map[string]interface{}{} + indentLevel = "\t\t" + } - case "\t\t": - if strings.Contains(l, ":") { - kvp := strings.Split(l, ":") - k := strings.TrimSpace(kvp[0]) - v := strings.TrimSpace(kvp[1]) - data[k] = v - } - } - } else { - if indentLevel != "\t\t" { - panic("Malformed indental file") - } else { - users[name] = &auth.User{ - Name: name, - Pass: pass, - Session: session, - LoginTime: loginTime, - LastSeen: lastSeen, - Data: data, - } - indentLevel = "" - } - } - } - return users, nil + case "\t\t": + if strings.Contains(l, ":") { + kvp := strings.Split(l, ":") + k := strings.TrimSpace(kvp[0]) + v := strings.TrimSpace(kvp[1]) + data[k] = v + } + } + } else { + if indentLevel != "\t\t" { + panic("Malformed indental file") + } else { + users[name] = &auth.User{ + Name: name, + Pass: pass, + Session: session, + LoginTime: loginTime, + LastSeen: lastSeen, + Data: data, + } + indentLevel = "" + } + } + } + return users, nil } func writeDB(filePath string, users map[string]*auth.User) error { - f, err := os.Create(filePath) - if err != nil { - return err - } + f, err := os.Create(filePath) + if err != nil { + return err + } - defer f.Close() + defer f.Close() - for _, user := range users { - f.WriteString(fmt.Sprintf("%s\n\tpass: %s\n\tsession: %s\n\tloginTime: %s\n\tlastSeen: %s\n\tdata\n", - user.Name, - user.Pass, - user.Session, - user.LoginTime, - user.LastSeen)); - for k, v := range user.Data { - f.WriteString(fmt.Sprintf("\t\t%s: %s\n", k, v)) - } - f.WriteString("\n") - } - f.Sync() - return nil + for _, user := range users { + f.WriteString(fmt.Sprintf("%s\n\tpass: %s\n\tsession: %s\n\tloginTime: %s\n\tlastSeen: %s\n\tdata\n", + user.Name, + user.Pass, + user.Session, + user.LoginTime, + user.LastSeen)) + for k, v := range user.Data { + f.WriteString(fmt.Sprintf("\t\t%s: %s\n", k, v)) + } + f.WriteString("\n") + } + f.Sync() + return nil } diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..138ed71 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "context" + "net/http" + "nilfm.cc/git/quartzgun/auth" + "nilfm.cc/git/quartzgun/cookie" +) + +func Protected(next http.Handler, userStore auth.UserStore) http.Handler { + handlerFunc := func(w http.ResponseWriter, req *http.Request) { + user, err := cookie.GetToken("user", req) + if err == nil { + session, err := cookie.GetToken("session", req) + if err == nil { + login, err := userStore.ValidateUser(user, session) + if err == nil && login { + next.ServeHTTP(w, req) + return + } + } + } + req.Method = http.MethodGet + http.Redirect(w, req, "/login", http.StatusTemporaryRedirect) + } + + return http.HandlerFunc(handlerFunc) +} + +func Authorize(next string, userStore auth.UserStore) http.Handler { + handlerFunc := func(w http.ResponseWriter, req *http.Request) { + err := auth.Login( + req.FormValue("user"), + req.FormValue("password"), + userStore, + w, + 24*7*52) + if err == nil { + req.Method = http.MethodGet + http.Redirect(w, req, next, http.StatusOK) + } else { + *req = *req.WithContext( + context.WithValue( + req.Context(), + "message", + "Incorrect credentials")) + req.Method = http.MethodGet + http.Redirect(w, req, "/login", http.StatusTemporaryRedirect) + } + } + + return http.HandlerFunc(handlerFunc) +} diff --git a/quartzgun_test.go b/quartzgun_test.go index b21af82..c325926 100644 --- a/quartzgun_test.go +++ b/quartzgun_test.go @@ -1,53 +1,53 @@ package main import ( - "fmt" + "context" + "fmt" + "html/template" "net/http" - "html/template" - "context" - "nilfm.cc/git/quartzgun/router" - "nilfm.cc/git/quartzgun/renderer" - "nilfm.cc/git/quartzgun/indentalUserDB" - "testing" + "nilfm.cc/git/quartzgun/indentalUserDB" + "nilfm.cc/git/quartzgun/renderer" + "nilfm.cc/git/quartzgun/router" + "testing" ) func AddContent(next http.Handler) http.Handler { - handlerFunc := func(w http.ResponseWriter, req *http.Request) { - if !req.Form.Has("Content") { - req.Form.Add("Content", "Yesssssss") - } - next.ServeHTTP(w, req) - } - return http.HandlerFunc(handlerFunc) + handlerFunc := func(w http.ResponseWriter, req *http.Request) { + if !req.Form.Has("Content") { + req.Form.Add("Content", "Yesssssss") + } + next.ServeHTTP(w, req) + } + return http.HandlerFunc(handlerFunc) } func ApiSomething(next http.Handler) http.Handler { - handlerFunc := func(w http.ResponseWriter, req *http.Request) { - *req = *req.WithContext(context.WithValue(req.Context(), "apiData", "something")) - next.ServeHTTP(w, req) - } + handlerFunc := func(w http.ResponseWriter, req *http.Request) { + *req = *req.WithContext(context.WithValue(req.Context(), "apiData", "something")) + next.ServeHTTP(w, req) + } - return http.HandlerFunc(handlerFunc) + return http.HandlerFunc(handlerFunc) } -func TestMain(m *testing.M){ - udb := indentalUserDB.CreateIndentalUserDB("testData/userDB.ndtl") - udb.AddUser("nilix", "questing") - sesh, _ := udb.InitiateSession("nilix", "questing") +func TestMain(m *testing.M) { + udb := indentalUserDB.CreateIndentalUserDB("testData/userDB.ndtl") + udb.AddUser("nilix", "questing") + sesh, _ := udb.InitiateSession("nilix", "questing") - fmt.Printf("%s // %s\n", sesh, sesh) - rtr := &router.Router{ - StaticPaths: map[string]string{ - "/static": "testData/static", - }, - Fallback: *template.Must(template.ParseFiles("testData/templates/error.html", "testData/templates/footer.html")), - } + fmt.Printf("%s // %s\n", sesh, sesh) + rtr := &router.Router{ + StaticPaths: map[string]string{ + "/static": "testData/static", + }, + Fallback: *template.Must(template.ParseFiles("testData/templates/error.html", "testData/templates/footer.html")), + } - rtr.Get("/", AddContent(renderer.Template("testData/templates/test.html"))) + rtr.Get("/", AddContent(renderer.Template("testData/templates/test.html"))) - rtr.Get("/json", ApiSomething(renderer.JSON("apiData"))) + rtr.Get("/json", ApiSomething(renderer.JSON("apiData"))) - rtr.Get(`/thing/(?P\w+)`, renderer.Template("testData/templates/paramTest.html")) + rtr.Get(`/thing/(?P\w+)`, renderer.Template("testData/templates/paramTest.html")) - http.ListenAndServe(":8080", rtr) + http.ListenAndServe(":8080", rtr) } diff --git a/renderer/renderer.go b/renderer/renderer.go index 031986e..00288a8 100644 --- a/renderer/renderer.go +++ b/renderer/renderer.go @@ -1,48 +1,48 @@ package renderer import ( - "net/http" - "html/template" - "encoding/json" - "encoding/xml" + "encoding/json" + "encoding/xml" + "html/template" + "net/http" ) func Template(t ...string) http.Handler { - tmpl := template.Must(template.ParseFiles(t...)) + tmpl := template.Must(template.ParseFiles(t...)) - handlerFunc := func(w http.ResponseWriter, req *http.Request) { - tmpl.Execute(w, req) - } + handlerFunc := func(w http.ResponseWriter, req *http.Request) { + tmpl.Execute(w, req) + } - return http.HandlerFunc(handlerFunc) + return http.HandlerFunc(handlerFunc) } func JSON(key string) http.Handler { - handlerFunc := func(w http.ResponseWriter, req *http.Request) { - apiData := req.Context().Value(key) + handlerFunc := func(w http.ResponseWriter, req *http.Request) { + apiData := req.Context().Value(key) - data, err := json.Marshal(apiData) - if err != nil { - panic(err.Error()) - } - w.Header().Set("Content-Type", "application/json") - w.Write(data) - } + data, err := json.Marshal(apiData) + if err != nil { + panic(err.Error()) + } + w.Header().Set("Content-Type", "application/json") + w.Write(data) + } - return http.HandlerFunc(handlerFunc) + return http.HandlerFunc(handlerFunc) } func XML(key string) http.Handler { - handlerFunc := func(w http.ResponseWriter, req *http.Request) { - apiData := req.Context().Value(key) + handlerFunc := func(w http.ResponseWriter, req *http.Request) { + apiData := req.Context().Value(key) - data, err := xml.MarshalIndent(apiData, "", " ") - if err != nil { - panic(err.Error()) - } - w.Header().Set("Content-Type", "application/xml") - w.Write(data) - } + data, err := xml.MarshalIndent(apiData, "", " ") + if err != nil { + panic(err.Error()) + } + w.Header().Set("Content-Type", "application/xml") + w.Write(data) + } - return http.HandlerFunc(handlerFunc) + return http.HandlerFunc(handlerFunc) } diff --git a/router/router.go b/router/router.go index 7bdaf8d..2d9ea3c 100644 --- a/router/router.go +++ b/router/router.go @@ -1,127 +1,127 @@ package router import ( - "net/http" - "html/template" - "regexp" - "log" - "strconv" - "strings" - "path" - "os" - "errors" - "context" + "context" + "errors" + "html/template" + "log" + "net/http" + "os" + "path" + "regexp" + "strconv" + "strings" ) type Router struct { - /* This is the template for error pages */ - Fallback template.Template - /* Routes are only filled by using the appropriate methods. */ - routes []Route - /* StaticPaths can be filled from outside when constructing the Router. - * key = uri - * value = file path - */ - StaticPaths map[string]string + /* This is the template for error pages */ + Fallback template.Template + /* Routes are only filled by using the appropriate methods. */ + routes []Route + /* StaticPaths can be filled from outside when constructing the Router. + * key = uri + * value = file path + */ + StaticPaths map[string]string } type Route struct { - path *regexp.Regexp - handlerMap map[string]http.Handler + path *regexp.Regexp + handlerMap map[string]http.Handler } func (self *Router) Get(path string, h http.Handler) { - self.AddRoute("GET", path, h) + self.AddRoute("GET", path, h) } func (self *Router) Post(path string, h http.Handler) { - self.AddRoute("POST", path, h) + self.AddRoute("POST", path, h) } -func (self *Router) Put(path string, h http.Handler) { - self.AddRoute("PUT", path, h) +func (self *Router) Put(path string, h http.Handler) { + self.AddRoute("PUT", path, h) } func (self *Router) Delete(path string, h http.Handler) { - self.AddRoute("DELETE", path, h) + self.AddRoute("DELETE", path, h) } func (self *Router) AddRoute(method string, path string, h http.Handler) { - exactPath := regexp.MustCompile("^" + path + "$") + exactPath := regexp.MustCompile("^" + path + "$") - /* If the route already exists, try to add this method to the ServerTask map. */ - for _, r := range self.routes { - if r.path == exactPath { - r.handlerMap[method] = h - return - } - } + /* If the route already exists, try to add this method to the ServerTask map. */ + for _, r := range self.routes { + if r.path == exactPath { + r.handlerMap[method] = h + return + } + } - /* Otherwise add a new route */ - self.routes = append(self.routes, Route{ - path: exactPath, - handlerMap: map[string]http.Handler{method: h}, - }) + /* Otherwise add a new route */ + self.routes = append(self.routes, Route{ + path: exactPath, + handlerMap: map[string]http.Handler{method: h}, + }) } func (self *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - /* Show the 500 error page if we panic */ - defer func() { - if r := recover(); r != nil { - log.Println("ERROR:", r) - self.ErrorPage(w, req, 500, "There was an error on the server.") - } - }() + /* Show the 500 error page if we panic */ + defer func() { + if r := recover(); r != nil { + log.Println("ERROR:", r) + self.ErrorPage(w, req, 500, "There was an error on the server.") + } + }() - /* If the request matches any our StaticPaths, try to serve a file. */ - for uri, dir := range self.StaticPaths { - if req.Method == "GET" && strings.HasPrefix(req.URL.Path, uri) { - restOfUri := strings.TrimPrefix(req.URL.Path, uri) - p := path.Join(dir, restOfUri) - p = path.Clean(p) + /* If the request matches any our StaticPaths, try to serve a file. */ + for uri, dir := range self.StaticPaths { + if req.Method == "GET" && strings.HasPrefix(req.URL.Path, uri) { + restOfUri := strings.TrimPrefix(req.URL.Path, uri) + p := path.Join(dir, restOfUri) + p = path.Clean(p) - /* If the file exists, try to serve it. */ - info, err := os.Stat(p); - if err == nil && !info.IsDir() { - http.ServeFile(w, req, p) - /* Handle the common errors */ - } else if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrExist) { - self.ErrorPage(w, req, 404, "The requested file does not exist") - } else if errors.Is(err, os.ErrPermission) || info.IsDir() { - self.ErrorPage(w, req, 403, "Access forbidden") - /* If it's some weird error, serve a 500. */ - } else { - self.ErrorPage(w, req, 500, "Internal server error") - } + /* If the file exists, try to serve it. */ + info, err := os.Stat(p) + if err == nil && !info.IsDir() { + http.ServeFile(w, req, p) + /* Handle the common errors */ + } else if errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrExist) { + self.ErrorPage(w, req, 404, "The requested file does not exist") + } else if errors.Is(err, os.ErrPermission) || info.IsDir() { + self.ErrorPage(w, req, 403, "Access forbidden") + /* If it's some weird error, serve a 500. */ + } else { + self.ErrorPage(w, req, 500, "Internal server error") + } - return - } - } + return + } + } - /* Otherwise, this is a normal route */ - for _, r := range self.routes { + /* Otherwise, this is a normal route */ + for _, r := range self.routes { - /* Pull the params out of the regex; - * If the path doesn't match the regex, params will be nil. - */ - params := r.Match(req) - if params == nil { - continue - } - for method, handler := range r.handlerMap { - if method == req.Method { - /* Parse the form and add the params to the context */ - req.ParseForm() - ProcessParams(req, params) - /* handle the request! */ - handler.ServeHTTP(w, req); - return - } - } - } - self.ErrorPage(w, req, 404, "The page you requested does not exist!") + /* Pull the params out of the regex; + * If the path doesn't match the regex, params will be nil. + */ + params := r.Match(req) + if params == nil { + continue + } + for method, handler := range r.handlerMap { + if method == req.Method { + /* Parse the form and add the params to the context */ + req.ParseForm() + ProcessParams(req, params) + /* handle the request! */ + handler.ServeHTTP(w, req) + return + } + } + } + self.ErrorPage(w, req, 404, "The page you requested does not exist!") } /******************* @@ -129,31 +129,31 @@ func (self *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { *******************/ func ProcessParams(req *http.Request, params map[string]string) { - *req = *req.WithContext(context.WithValue(req.Context(), "params", params)) + *req = *req.WithContext(context.WithValue(req.Context(), "params", params)) } func (self *Route) Match(r *http.Request) map[string]string { - match := self.path.FindStringSubmatch(r.URL.Path) - if match == nil { - return nil - } + match := self.path.FindStringSubmatch(r.URL.Path) + if match == nil { + return nil + } - params := map[string]string{} - groupNames := self.path.SubexpNames() + params := map[string]string{} + groupNames := self.path.SubexpNames() - for i, group := range match { - params[groupNames[i]] = group - } + for i, group := range match { + params[groupNames[i]] = group + } - return params + return params } func (self *Router) ErrorPage(w http.ResponseWriter, req *http.Request, code int, errMsg string) { - w.WriteHeader(code) - params := map[string]string{ - "ErrorCode": strconv.Itoa(code), - "ErrorMessage": errMsg, - } - ProcessParams(req, params) - self.Fallback.Execute(w, req) + w.WriteHeader(code) + params := map[string]string{ + "ErrorCode": strconv.Itoa(code), + "ErrorMessage": errMsg, + } + ProcessParams(req, params) + self.Fallback.Execute(w, req) }