| @@ -10,6 +10,7 @@ import ( | |||||
| ) | ) | ||||
| type Claims struct { | type Claims struct { | ||||
| ID uint `json:"id"` | |||||
| Username string `json:"username"` | Username string `json:"username"` | ||||
| Privileges uint `json:"privileges"` | Privileges uint `json:"privileges"` | ||||
| jwt.RegisteredClaims | jwt.RegisteredClaims | ||||
| @@ -66,9 +66,10 @@ type Contribution struct { | |||||
| type User struct { | type User struct { | ||||
| ModelBase | ModelBase | ||||
| Username string `json:"username"` | |||||
| Password string `json:"password"` | |||||
| Privileges uint `json:"admin"` | |||||
| Username string `json:"username"` | |||||
| Password string `json:"password"` | |||||
| Privileges uint `json:"admin"` | |||||
| LastLogin *time.Time `json:"lastLogin"` | |||||
| } | } | ||||
| var Db *gorm.DB | var Db *gorm.DB | ||||
| @@ -0,0 +1,45 @@ | |||||
| package endpoints | |||||
| import ( | |||||
| "encoding/json" | |||||
| "net/http" | |||||
| "time" | |||||
| . "github.com/imosed/signet/data" | |||||
| ) | |||||
| type ChangePasswordRequest struct { | |||||
| UserID uint `json:"userID"` | |||||
| Password string `json:"password"` | |||||
| } | |||||
| func ChangePassword(w http.ResponseWriter, r *http.Request) { | |||||
| var req ChangePasswordRequest | |||||
| err := json.NewDecoder(r.Body).Decode(&req) | |||||
| if err != nil { | |||||
| panic("Could not decode body") | |||||
| } | |||||
| var user User | |||||
| Db.Table("users").First(&user, req.UserID) | |||||
| var password string | |||||
| password, err = GetHashedPassword(req.Password) | |||||
| if err != nil { | |||||
| panic("Could not get password") | |||||
| } | |||||
| if user.LastLogin == nil { | |||||
| Db.Table("users").Where("id = ?", req.UserID).Updates(map[string]interface{}{"last_login": time.Now(), "password": password}) | |||||
| } else { | |||||
| Db.Table("users").Where("id = ?", req.UserID).Update("password = ?", password) | |||||
| } | |||||
| var resp SuccessResponse | |||||
| resp.Success = true | |||||
| err = json.NewEncoder(w).Encode(resp) | |||||
| if err != nil { | |||||
| panic("Could not deliver response") | |||||
| } | |||||
| } | |||||
| @@ -10,14 +10,15 @@ import ( | |||||
| ) | ) | ||||
| type EscalatePrivilegesRequest struct { | type EscalatePrivilegesRequest struct { | ||||
| Username string | |||||
| UserID uint `json:"userID"` | |||||
| Privileges uint `json:"privileges"` | |||||
| } | } | ||||
| func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||||
| func ChangePrivileges(w http.ResponseWriter, r *http.Request) { | |||||
| var req EscalatePrivilegesRequest | var req EscalatePrivilegesRequest | ||||
| err := json.NewDecoder(r.Body).Decode(&req) | err := json.NewDecoder(r.Body).Decode(&req) | ||||
| if err != nil { | if err != nil { | ||||
| log.Error().Err(err).Msg("Could not decode body in EscalatePrivileges call") | |||||
| log.Error().Err(err).Msg("Could not decode body in ChangePrivileges call") | |||||
| return | return | ||||
| } | } | ||||
| @@ -28,8 +29,8 @@ func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||||
| claims, err = auth.GetUserClaims(r) | claims, err = auth.GetUserClaims(r) | ||||
| if claims.Privileges < 2 { | if claims.Privileges < 2 { | ||||
| Db.Table("users").Where("username = ?", req.Username).Find(&user) | |||||
| if user.Privileges < 2 { | |||||
| Db.Table("users").Where("id = ?", req.UserID).Find(&user) | |||||
| if req.Privileges == SuperUser { | |||||
| resp.Success = false | resp.Success = false | ||||
| err = json.NewEncoder(w).Encode(resp) | err = json.NewEncoder(w).Encode(resp) | ||||
| @@ -39,7 +40,8 @@ func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||||
| return | return | ||||
| } | } | ||||
| user.Privileges = AdminPlus | |||||
| user.Privileges = req.Privileges | |||||
| Db.Save(user) | |||||
| resp.Success = true | resp.Success = true | ||||
| } else { | } else { | ||||
| resp.Success = false | resp.Success = false | ||||
| @@ -13,7 +13,8 @@ import ( | |||||
| ) | ) | ||||
| type LoginResponse struct { | type LoginResponse struct { | ||||
| Token *string `json:"token"` | |||||
| Token *string `json:"token"` | |||||
| LastLogin *time.Time `json:"lastLogin"` | |||||
| } | } | ||||
| func Login(w http.ResponseWriter, r *http.Request) { | func Login(w http.ResponseWriter, r *http.Request) { | ||||
| @@ -24,14 +25,12 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||||
| return | return | ||||
| } | } | ||||
| var userData struct { | |||||
| ID uint | |||||
| Password string | |||||
| Privileges uint | |||||
| } | |||||
| var userData User | |||||
| var loginTime = time.Now() | |||||
| var resp LoginResponse | var resp LoginResponse | ||||
| Db.Table("users").Select("id, password, privileges"). | |||||
| Db.Table("users").Select("id, password, privileges, last_login"). | |||||
| Where("username = ?", req.Username).First(&userData) | Where("username = ?", req.Username).First(&userData) | ||||
| var passwordMatches bool | var passwordMatches bool | ||||
| passwordMatches, err = ComparePasswordAndHash(req.Password, userData.Password) | passwordMatches, err = ComparePasswordAndHash(req.Password, userData.Password) | ||||
| @@ -40,7 +39,6 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||||
| return | return | ||||
| } | } | ||||
| if !passwordMatches { | if !passwordMatches { | ||||
| resp.Token = nil | |||||
| err = json.NewEncoder(w).Encode(resp) | err = json.NewEncoder(w).Encode(resp) | ||||
| if err != nil { | if err != nil { | ||||
| log.Error().Err(err).Msg("Failed to deliver failed login attempt response") | log.Error().Err(err).Msg("Failed to deliver failed login attempt response") | ||||
| @@ -49,10 +47,11 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||||
| } | } | ||||
| token := jwt.NewWithClaims(jwt.SigningMethodHS256, &auth.Claims{ | token := jwt.NewWithClaims(jwt.SigningMethodHS256, &auth.Claims{ | ||||
| ID: userData.ID, | |||||
| Username: req.Username, | Username: req.Username, | ||||
| Privileges: userData.Privileges, | Privileges: userData.Privileges, | ||||
| RegisteredClaims: jwt.RegisteredClaims{ | RegisteredClaims: jwt.RegisteredClaims{ | ||||
| ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Hour)), | |||||
| ExpiresAt: jwt.NewNumericDate(loginTime.Add(10 * time.Hour)), | |||||
| }, | }, | ||||
| }) | }) | ||||
| @@ -63,6 +62,12 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||||
| return | return | ||||
| } | } | ||||
| resp.Token = &tokenString | resp.Token = &tokenString | ||||
| resp.LastLogin = userData.LastLogin | |||||
| if userData.LastLogin != nil { | |||||
| // need to set this after the user changes their password | |||||
| Db.Table("users").Where("id = ?", userData.ID).Update("last_login", loginTime) | |||||
| } | |||||
| err = json.NewEncoder(w).Encode(resp) | err = json.NewEncoder(w).Encode(resp) | ||||
| if err != nil { | if err != nil { | ||||
| @@ -124,6 +124,17 @@ func determinePrivileges() uint { | |||||
| } | } | ||||
| } | } | ||||
| func GetHashedPassword(password string) (encodedHash string, err error) { | |||||
| hash, err := GenerateHash(password, &Params{ | |||||
| Memory: uint32(viper.GetInt("hashing.memory")), | |||||
| Iterations: uint32(viper.GetInt("hashing.iterations")), | |||||
| Parallelism: uint8(viper.GetInt("hashing.parallelism")), | |||||
| SaltLength: uint32(viper.GetInt("hashing.saltLength")), | |||||
| KeyLength: uint32(viper.GetInt("hashing.keyLength")), | |||||
| }) | |||||
| return hash, err | |||||
| } | |||||
| func Register(w http.ResponseWriter, r *http.Request) { | func Register(w http.ResponseWriter, r *http.Request) { | ||||
| var req AuthenticationRequest | var req AuthenticationRequest | ||||
| err := json.NewDecoder(r.Body).Decode(&req) | err := json.NewDecoder(r.Body).Decode(&req) | ||||
| @@ -144,13 +155,7 @@ func Register(w http.ResponseWriter, r *http.Request) { | |||||
| } | } | ||||
| if noUsersRegistered() || claims.Privileges <= AdminPlus { | if noUsersRegistered() || claims.Privileges <= AdminPlus { | ||||
| hash, err := GenerateHash(req.Password, &Params{ | |||||
| Memory: uint32(viper.GetInt("hashing.memory")), | |||||
| Iterations: uint32(viper.GetInt("hashing.iterations")), | |||||
| Parallelism: uint8(viper.GetInt("hashing.parallelism")), | |||||
| SaltLength: uint32(viper.GetInt("hashing.saltLength")), | |||||
| KeyLength: uint32(viper.GetInt("hashing.keyLength")), | |||||
| }) | |||||
| hash, err := GetHashedPassword(req.Password) | |||||
| if err != nil { | if err != nil { | ||||
| log.Error().Err(err).Msg("Could not generate hash for registration") | log.Error().Err(err).Msg("Could not generate hash for registration") | ||||
| return | return | ||||
| @@ -48,8 +48,9 @@ func main() { | |||||
| router.HandleFunc("/ContributorStream", endpoints.ContributorStream) | router.HandleFunc("/ContributorStream", endpoints.ContributorStream) | ||||
| router.HandleFunc("/Login", endpoints.Login) | router.HandleFunc("/Login", endpoints.Login) | ||||
| router.HandleFunc("/Register", endpoints.Register) | router.HandleFunc("/Register", endpoints.Register) | ||||
| router.HandleFunc("/ChangePassword", endpoints.ChangePassword) | |||||
| router.HandleFunc("/NearlyCompleteFunds", endpoints.NearlyCompleteFunds) | router.HandleFunc("/NearlyCompleteFunds", endpoints.NearlyCompleteFunds) | ||||
| router.HandleFunc("/EscalatePrivileges", endpoints.EscalatePrivileges) | |||||
| router.HandleFunc("/ChangePrivileges", endpoints.ChangePrivileges) | |||||
| router.HandleFunc("/UsersExist", endpoints.UsersExist) | router.HandleFunc("/UsersExist", endpoints.UsersExist) | ||||
| port := viper.GetInt("app.port") | port := viper.GetInt("app.port") | ||||