Files
nebula/internal/auth/middleware.go
2026-03-10 16:26:48 +08:00

105 lines
2.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package auth
import (
"nebula/internal/api/response"
"strings"
"github.com/gin-gonic/gin"
)
// JWTMiddleware returns a gin handler that requires a valid bearer token in
// the Authorization header.
//
// The header must have the form "<scheme> <token>" with the scheme "Bearer".
// The scheme is compared case-insensitively because HTTP authentication scheme
// names are case-insensitive. When the header is missing, malformed, or the
// token is rejected by jwtService.ValidateToken, response.Fail is called with
// code 401 and the request is aborted. Otherwise the claims' UserID, Username
// and Role are stored in the context under "user_id", "username" and "role",
// and the next handler runs.
func JWTMiddleware(jwtService *JWTService) gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			response.Fail(c, 401, "missing authorization header")
			c.Abort()
			return
		}

		// Split the header into the scheme and the credentials.
		parts := strings.SplitN(authHeader, " ", 2)
		if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
			response.Fail(c, 401, "invalid authorization header format")
			c.Abort()
			return
		}

		// Tolerate extra whitespace between the scheme and the token.
		tokenString := strings.TrimSpace(parts[1])

		// Validate the token.
		claims, err := jwtService.ValidateToken(tokenString)
		if err != nil {
			response.Fail(c, 401, "invalid or expired token")
			c.Abort()
			return
		}

		// Expose the authenticated user's details to downstream handlers.
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("role", claims.Role)
		c.Next()
	}
}
// OptionalJWTMiddleware returns a gin handler that authenticates the request
// when possible but never rejects it.
//
// If the Authorization header carries a "Bearer" token (scheme compared
// case-insensitively, since HTTP authentication scheme names are
// case-insensitive) that jwtService.ValidateToken accepts, the claims' UserID,
// Username and Role are stored in the context under "user_id", "username" and
// "role". A missing header, a malformed header, or an invalid token is
// silently ignored. The next handler always runs.
func OptionalJWTMiddleware(jwtService *JWTService) gin.HandlerFunc {
	return func(c *gin.Context) {
		// SplitN on an empty header yields a single element, so a missing
		// header falls through the same guard as a malformed one.
		parts := strings.SplitN(c.GetHeader("Authorization"), " ", 2)
		if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
			// Tolerate extra whitespace between the scheme and the token.
			tokenString := strings.TrimSpace(parts[1])
			if claims, err := jwtService.ValidateToken(tokenString); err == nil {
				c.Set("user_id", claims.UserID)
				c.Set("username", claims.Username)
				c.Set("role", claims.Role)
			}
		}
		c.Next()
	}
}
// AdminMiddleware returns a gin handler that only lets administrators
// through. It reads the "role" value from the context, so it must be
// installed after JWTMiddleware, which sets that key.
//
// The request proceeds only when "role" is present and equals the string
// "admin"; otherwise response.Fail is called with code 403 and the request
// is aborted.
func AdminMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if current, ok := c.Get("role"); ok && current == "admin" {
			c.Next()
			return
		}

		response.Fail(c, 403, "admin access required")
		c.Abort()
	}
}
// GetCurrentUserID returns the value stored in the context under "user_id".
//
// The boolean result is false, and the string empty, when the key is absent
// or when the stored value is not a string. The comma-ok assertion is used
// so that an unexpected value type does not panic the request handler.
func GetCurrentUserID(c *gin.Context) (string, bool) {
	value, exists := c.Get("user_id")
	if !exists {
		return "", false
	}
	userID, ok := value.(string)
	return userID, ok
}
// GetCurrentUsername returns the value stored in the context under
// "username".
//
// The boolean result is false, and the string empty, when the key is absent
// or when the stored value is not a string. The comma-ok assertion is used
// so that an unexpected value type does not panic the request handler.
func GetCurrentUsername(c *gin.Context) (string, bool) {
	value, exists := c.Get("username")
	if !exists {
		return "", false
	}
	username, ok := value.(string)
	return username, ok
}
// GetCurrentUserRole returns the value stored in the context under "role".
//
// The boolean result is false, and the string empty, when the key is absent
// or when the stored value is not a string. The comma-ok assertion is used
// so that an unexpected value type does not panic the request handler.
func GetCurrentUserRole(c *gin.Context) (string, bool) {
	value, exists := c.Get("role")
	if !exists {
		return "", false
	}
	role, ok := value.(string)
	return role, ok
}