package auth import ( "context" "mealprep/database" "net/http" ) type contextKey string const ( userIDKey contextKey = "userID" ) // RequireAuth is middleware that checks if user is authenticated func RequireAuth(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Get session cookie cookie, err := r.Cookie("session_token") if err != nil { http.Redirect(w, r, "/login", http.StatusSeeOther) return } // Validate session session, err := database.GetSession(cookie.Value) if err != nil { // Invalid or expired session http.SetCookie(w, &http.Cookie{ Name: "session_token", Value: "", MaxAge: -1, HttpOnly: true, Path: "/", }) http.Redirect(w, r, "/login", http.StatusSeeOther) return } // Add user ID to request context ctx := context.WithValue(r.Context(), userIDKey, session.UserID) next.ServeHTTP(w, r.WithContext(ctx)) } } // GetUserID retrieves the user ID from the request context func GetUserID(r *http.Request) int { userID, ok := r.Context().Value(userIDKey).(int) if !ok { return 0 } return userID } // RedirectIfAuthenticated redirects to home if user is already logged in func RedirectIfAuthenticated(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session_token") if err == nil { // Check if session is valid _, err := database.GetSession(cookie.Value) if err == nil { // Valid session, redirect to home http.Redirect(w, r, "/", http.StatusSeeOther) return } } next.ServeHTTP(w, r) } }