first commit
This commit is contained in:
121
internal/middleware/auth.go
Normal file
121
internal/middleware/auth.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"homework-manager/internal/service"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const UserIDKey = "user_id"
|
||||
const UserRoleKey = "user_role"
|
||||
const UserNameKey = "user_name"
|
||||
|
||||
func AuthRequired(authService *service.AuthService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
userID := session.Get(UserIDKey)
|
||||
|
||||
if userID == nil {
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
user, err := authService.GetUserByID(userID.(uint))
|
||||
if err != nil {
|
||||
session.Clear()
|
||||
session.Save()
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(UserIDKey, user.ID)
|
||||
c.Set(UserRoleKey, user.Role)
|
||||
c.Set(UserNameKey, user.Name)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func AdminRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, exists := c.Get(UserRoleKey)
|
||||
if !exists || role != "admin" {
|
||||
c.HTML(http.StatusForbidden, "error.html", gin.H{
|
||||
"title": "アクセス拒否",
|
||||
"message": "この操作には管理者権限が必要です。",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func GuestOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
userID := session.Get(UserIDKey)
|
||||
|
||||
if userID != nil {
|
||||
c.Redirect(http.StatusFound, "/")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func InjectUserInfo() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
userID := session.Get(UserIDKey)
|
||||
|
||||
if userID != nil {
|
||||
c.Set(UserIDKey, userID.(uint))
|
||||
c.Set(UserRoleKey, session.Get(UserRoleKey))
|
||||
c.Set(UserNameKey, session.Get(UserNameKey))
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type APIKeyValidator interface {
|
||||
ValidateAPIKey(key string) (uint, error)
|
||||
}
|
||||
|
||||
func APIKeyAuth(validator APIKeyValidator) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
const bearerPrefix = "Bearer "
|
||||
if len(authHeader) <= len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid authorization format. Use: Bearer <api_key>"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := authHeader[len(bearerPrefix):]
|
||||
|
||||
userID, err := validator.ValidateAPIKey(apiKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(UserIDKey, userID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
119
internal/middleware/csrf.go
Normal file
119
internal/middleware/csrf.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
csrfTokenKey = "csrf_token"
|
||||
csrfTokenFormKey = "_csrf"
|
||||
csrfTokenHeader = "X-CSRF-Token"
|
||||
)
|
||||
|
||||
type CSRFConfig struct {
|
||||
Secret string
|
||||
}
|
||||
|
||||
func generateCSRFToken(secret string) (string, error) {
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write(randomBytes)
|
||||
signature := h.Sum(nil)
|
||||
|
||||
token := append(randomBytes, signature...)
|
||||
return base64.URLEncoding.EncodeToString(token), nil
|
||||
}
|
||||
|
||||
func validateCSRFToken(token, secret string) bool {
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(decoded) != 64 {
|
||||
return false
|
||||
}
|
||||
|
||||
randomBytes := decoded[:32]
|
||||
providedSignature := decoded[32:]
|
||||
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write(randomBytes)
|
||||
expectedSignature := h.Sum(nil)
|
||||
|
||||
return hmac.Equal(providedSignature, expectedSignature)
|
||||
}
|
||||
|
||||
func CSRF(config CSRFConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
|
||||
csrfToken, ok := session.Get(csrfTokenKey).(string)
|
||||
if !ok || csrfToken == "" || !validateCSRFToken(csrfToken, config.Secret) {
|
||||
newToken, err := generateCSRFToken(config.Secret)
|
||||
if err != nil {
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
csrfToken = newToken
|
||||
session.Set(csrfTokenKey, csrfToken)
|
||||
session.Save()
|
||||
}
|
||||
|
||||
c.Set(csrfTokenKey, csrfToken)
|
||||
|
||||
method := strings.ToUpper(c.Request.Method)
|
||||
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
submittedToken := c.PostForm(csrfTokenFormKey)
|
||||
if submittedToken == "" {
|
||||
submittedToken = c.GetHeader(csrfTokenHeader)
|
||||
}
|
||||
|
||||
sessionToken := session.Get(csrfTokenKey)
|
||||
if sessionToken == nil || submittedToken != sessionToken.(string) {
|
||||
c.HTML(http.StatusForbidden, "error.html", gin.H{
|
||||
"title": "CSRFエラー",
|
||||
"message": "不正なリクエストです。ページを再読み込みしてください。",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
|
||||
newToken, err := generateCSRFToken(config.Secret)
|
||||
if err == nil {
|
||||
session.Set(csrfTokenKey, newToken)
|
||||
session.Save()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CSRFField(c *gin.Context) template.HTML {
|
||||
token, exists := c.Get(csrfTokenKey)
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
return template.HTML(`<input type="hidden" name="` + csrfTokenFormKey + `" value="` + token.(string) + `">`)
|
||||
}
|
||||
94
internal/middleware/ratelimit.go
Normal file
94
internal/middleware/ratelimit.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool
|
||||
Requests int
|
||||
Window int
|
||||
}
|
||||
|
||||
type rateLimitEntry struct {
|
||||
count int
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
entries map[string]*rateLimitEntry
|
||||
mu sync.Mutex
|
||||
config RateLimitConfig
|
||||
}
|
||||
|
||||
func newRateLimiter(config RateLimitConfig) *rateLimiter {
|
||||
rl := &rateLimiter{
|
||||
entries: make(map[string]*rateLimitEntry),
|
||||
config: config,
|
||||
}
|
||||
|
||||
go rl.cleanup()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, entry := range rl.entries {
|
||||
if now.After(entry.expiresAt) {
|
||||
delete(rl.entries, key)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) allow(key string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
entry, exists := rl.entries[key]
|
||||
|
||||
if !exists || now.After(entry.expiresAt) {
|
||||
rl.entries[key] = &rateLimitEntry{
|
||||
count: 1,
|
||||
expiresAt: now.Add(time.Duration(rl.config.Window) * time.Second),
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
entry.count++
|
||||
return entry.count <= rl.config.Requests
|
||||
}
|
||||
|
||||
func RateLimit(config RateLimitConfig) gin.HandlerFunc {
|
||||
if !config.Enabled {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
limiter := newRateLimiter(config)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
if !limiter.allow(clientIP) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "リクエスト数が制限を超えました。しばらくしてからお試しください。",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
52
internal/middleware/security.go
Normal file
52
internal/middleware/security.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type SecurityConfig struct {
|
||||
HTTPS bool
|
||||
}
|
||||
func SecurityHeaders(config SecurityConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if config.HTTPS {
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
csp := []string{
|
||||
"default-src 'self'",
|
||||
"script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net",
|
||||
"style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net",
|
||||
"font-src 'self' https://cdn.jsdelivr.net",
|
||||
"img-src 'self' data:",
|
||||
"connect-src 'self'",
|
||||
"frame-ancestors 'none'",
|
||||
}
|
||||
c.Header("Content-Security-Policy", strings.Join(csp, "; "))
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ForceHTTPS(config SecurityConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if config.HTTPS && c.Request.TLS == nil && c.Request.Header.Get("X-Forwarded-Proto") != "https" {
|
||||
|
||||
host := c.Request.Host
|
||||
target := "https://" + host + c.Request.URL.Path
|
||||
if len(c.Request.URL.RawQuery) > 0 {
|
||||
target += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
c.Redirect(301, target)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
14
internal/middleware/timer.go
Normal file
14
internal/middleware/timer.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RequestTimer() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set("startTime", time.Now())
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user