first commit

This commit is contained in:
2025-12-30 21:47:39 +09:00
commit 0a37314fa8
47 changed files with 6088 additions and 0 deletions

182
internal/config/config.go Normal file
View File

@@ -0,0 +1,182 @@
package config
import (
"log"
"os"
"gopkg.in/ini.v1"
)
type DatabaseConfig struct {
Driver string // sqlite, mysql, postgres
Path string // SQLiteの場合パス
Host string
Port string
User string
Password string
Name string
}
type Config struct {
Port string
SessionSecret string
Debug bool
AllowRegistration bool
HTTPS bool
CSRFSecret string
RateLimitEnabled bool
RateLimitRequests int
RateLimitWindow int
TrustedProxies []string
Database DatabaseConfig
}
func Load(configPath string) *Config {
cfg := &Config{
Port: "8080",
SessionSecret: "",
Debug: true,
AllowRegistration: true,
HTTPS: false,
CSRFSecret: "",
RateLimitEnabled: true,
RateLimitRequests: 100,
RateLimitWindow: 60,
Database: DatabaseConfig{
Driver: "sqlite",
Path: "homework.db",
Host: "localhost",
Port: "3306",
User: "root",
Password: "",
Name: "homework_manager",
},
}
if configPath == "" {
configPath = "config.ini"
}
if iniFile, err := ini.Load(configPath); err == nil {
log.Printf("Loading configuration from %s", configPath)
section := iniFile.Section("server")
if section.HasKey("port") {
cfg.Port = section.Key("port").String()
}
if section.HasKey("debug") {
cfg.Debug = section.Key("debug").MustBool(true)
}
section = iniFile.Section("database")
if section.HasKey("driver") {
cfg.Database.Driver = section.Key("driver").String()
}
if section.HasKey("path") {
cfg.Database.Path = section.Key("path").String()
}
if section.HasKey("host") {
cfg.Database.Host = section.Key("host").String()
}
if section.HasKey("port") {
cfg.Database.Port = section.Key("port").String()
}
if section.HasKey("user") {
cfg.Database.User = section.Key("user").String()
}
if section.HasKey("password") {
cfg.Database.Password = section.Key("password").String()
}
if section.HasKey("name") {
cfg.Database.Name = section.Key("name").String()
}
section = iniFile.Section("session")
if section.HasKey("secret") {
cfg.SessionSecret = section.Key("secret").String()
}
section = iniFile.Section("auth")
if section.HasKey("allow_registration") {
cfg.AllowRegistration = section.Key("allow_registration").MustBool(true)
}
section = iniFile.Section("security")
if section.HasKey("https") {
cfg.HTTPS = section.Key("https").MustBool(false)
}
if section.HasKey("csrf_secret") {
cfg.CSRFSecret = section.Key("csrf_secret").String()
}
if section.HasKey("rate_limit_enabled") {
cfg.RateLimitEnabled = section.Key("rate_limit_enabled").MustBool(true)
}
if section.HasKey("rate_limit_requests") {
cfg.RateLimitRequests = section.Key("rate_limit_requests").MustInt(100)
}
if section.HasKey("rate_limit_window") {
cfg.RateLimitWindow = section.Key("rate_limit_window").MustInt(60)
}
if section.HasKey("trusted_proxies") {
proxies := section.Key("trusted_proxies").String()
if proxies != "" {
cfg.TrustedProxies = []string{proxies}
}
}
} else {
log.Println("config.ini not found, using environment variables or defaults")
}
if port := os.Getenv("PORT"); port != "" {
cfg.Port = port
}
if dbDriver := os.Getenv("DATABASE_DRIVER"); dbDriver != "" {
cfg.Database.Driver = dbDriver
}
if dbPath := os.Getenv("DATABASE_PATH"); dbPath != "" {
cfg.Database.Path = dbPath
}
if dbHost := os.Getenv("DATABASE_HOST"); dbHost != "" {
cfg.Database.Host = dbHost
}
if dbPort := os.Getenv("DATABASE_PORT"); dbPort != "" {
cfg.Database.Port = dbPort
}
if dbUser := os.Getenv("DATABASE_USER"); dbUser != "" {
cfg.Database.User = dbUser
}
if dbPassword := os.Getenv("DATABASE_PASSWORD"); dbPassword != "" {
cfg.Database.Password = dbPassword
}
if dbName := os.Getenv("DATABASE_NAME"); dbName != "" {
cfg.Database.Name = dbName
}
if sessionSecret := os.Getenv("SESSION_SECRET"); sessionSecret != "" {
cfg.SessionSecret = sessionSecret
}
if os.Getenv("GIN_MODE") == "release" {
cfg.Debug = false
}
if allowReg := os.Getenv("ALLOW_REGISTRATION"); allowReg != "" {
cfg.AllowRegistration = allowReg == "true" || allowReg == "1"
}
if https := os.Getenv("HTTPS"); https != "" {
cfg.HTTPS = https == "true" || https == "1"
}
if csrfSecret := os.Getenv("CSRF_SECRET"); csrfSecret != "" {
cfg.CSRFSecret = csrfSecret
}
if trustedProxies := os.Getenv("TRUSTED_PROXIES"); trustedProxies != "" {
cfg.TrustedProxies = []string{trustedProxies}
}
if cfg.SessionSecret == "" {
log.Fatal("FATAL: Session secret is not set. Please set it in config.ini ([session] secret) or via SESSION_SECRET environment variable.")
}
if cfg.CSRFSecret == "" {
log.Fatal("FATAL: CSRF secret is not set. Please set it in config.ini ([security] csrf_secret) or via CSRF_SECRET environment variable.")
}
return cfg
}

View File

@@ -0,0 +1,79 @@
package database
import (
"fmt"
"homework-manager/internal/config"
"homework-manager/internal/models"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var DB *gorm.DB
func Connect(dbConfig config.DatabaseConfig, debug bool) error {
var logMode logger.LogLevel
if debug {
logMode = logger.Info
} else {
logMode = logger.Silent
}
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logMode),
}
var db *gorm.DB
var err error
switch dbConfig.Driver {
case "mysql":
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
dbConfig.User,
dbConfig.Password,
dbConfig.Host,
dbConfig.Port,
dbConfig.Name,
)
db, err = gorm.Open(mysql.Open(dsn), gormConfig)
case "postgres":
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
dbConfig.Host,
dbConfig.Port,
dbConfig.User,
dbConfig.Password,
dbConfig.Name,
)
db, err = gorm.Open(postgres.Open(dsn), gormConfig)
case "sqlite":
fallthrough
default:
db, err = gorm.Open(sqlite.Open(dbConfig.Path), gormConfig)
}
if err != nil {
return err
}
DB = db
return nil
}
func Migrate() error {
return DB.AutoMigrate(
&models.User{},
&models.Assignment{},
&models.APIKey{},
)
}
func GetDB() *gorm.DB {
return DB
}

View File

@@ -0,0 +1,165 @@
package handler
import (
"net/http"
"strconv"
"homework-manager/internal/middleware"
"homework-manager/internal/service"
"github.com/gin-gonic/gin"
)
type AdminHandler struct {
adminService *service.AdminService
apiKeyService *service.APIKeyService
}
func NewAdminHandler() *AdminHandler {
return &AdminHandler{
adminService: service.NewAdminService(),
apiKeyService: service.NewAPIKeyService(),
}
}
func (h *AdminHandler) getUserID(c *gin.Context) uint {
userID, _ := c.Get(middleware.UserIDKey)
return userID.(uint)
}
func (h *AdminHandler) Index(c *gin.Context) {
users, _ := h.adminService.GetAllUsers()
currentUserID := h.getUserID(c)
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "admin/users.html", gin.H{
"title": "ユーザー管理",
"users": users,
"currentUserID": currentUserID,
"isAdmin": true,
"userName": name,
})
}
func (h *AdminHandler) DeleteUser(c *gin.Context) {
adminID := h.getUserID(c)
targetID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "無効なユーザーID"})
return
}
err = h.adminService.DeleteUser(adminID, uint(targetID))
if err != nil {
users, _ := h.adminService.GetAllUsers()
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "admin/users.html", gin.H{
"title": "ユーザー管理",
"users": users,
"currentUserID": adminID,
"error": err.Error(),
"isAdmin": true,
"userName": name,
})
return
}
c.Redirect(http.StatusFound, "/admin/users")
}
func (h *AdminHandler) ChangeRole(c *gin.Context) {
adminID := h.getUserID(c)
targetID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "無効なユーザーID"})
return
}
newRole := c.PostForm("role")
err = h.adminService.ChangeRole(adminID, uint(targetID), newRole)
if err != nil {
users, _ := h.adminService.GetAllUsers()
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "admin/users.html", gin.H{
"title": "ユーザー管理",
"users": users,
"currentUserID": adminID,
"error": err.Error(),
"isAdmin": true,
"userName": name,
})
return
}
c.Redirect(http.StatusFound, "/admin/users")
}
func (h *AdminHandler) APIKeys(c *gin.Context) {
keys, _ := h.apiKeyService.GetAllAPIKeys()
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "admin/api_keys.html", gin.H{
"title": "APIキー管理",
"apiKeys": keys,
"isAdmin": true,
"userName": name,
})
}
func (h *AdminHandler) CreateAPIKey(c *gin.Context) {
userID := h.getUserID(c)
keyName := c.PostForm("name")
plainKey, _, err := h.apiKeyService.CreateAPIKey(userID, keyName)
keys, _ := h.apiKeyService.GetAllAPIKeys()
name, _ := c.Get(middleware.UserNameKey)
if err != nil {
RenderHTML(c, http.StatusOK, "admin/api_keys.html", gin.H{
"title": "APIキー管理",
"apiKeys": keys,
"error": err.Error(),
"isAdmin": true,
"userName": name,
})
return
}
RenderHTML(c, http.StatusOK, "admin/api_keys.html", gin.H{
"title": "APIキー管理",
"apiKeys": keys,
"newKey": plainKey,
"newKeyName": keyName,
"isAdmin": true,
"userName": name,
})
}
func (h *AdminHandler) DeleteAPIKey(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "無効なAPIキーID"})
return
}
err = h.apiKeyService.DeleteAPIKey(uint(id))
if err != nil {
keys, _ := h.apiKeyService.GetAllAPIKeys()
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "admin/api_keys.html", gin.H{
"title": "APIキー管理",
"apiKeys": keys,
"error": err.Error(),
"isAdmin": true,
"userName": name,
})
return
}
c.Redirect(http.StatusFound, "/admin/api-keys")
}

View File

@@ -0,0 +1,398 @@
package handler
import (
"net/http"
"strconv"
"time"
"homework-manager/internal/middleware"
"homework-manager/internal/service"
"github.com/gin-gonic/gin"
)
type APIHandler struct {
assignmentService *service.AssignmentService
}
func NewAPIHandler() *APIHandler {
return &APIHandler{
assignmentService: service.NewAssignmentService(),
}
}
func (h *APIHandler) getUserID(c *gin.Context) uint {
userID, _ := c.Get(middleware.UserIDKey)
return userID.(uint)
}
// ListAssignments returns all assignments for the authenticated user with pagination
// GET /api/v1/assignments?filter=pending&page=1&page_size=20
func (h *APIHandler) ListAssignments(c *gin.Context) {
userID := h.getUserID(c)
filter := c.Query("filter") // pending, completed, overdue
// Parse pagination parameters
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
// Validate pagination parameters
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
if pageSize > 100 {
pageSize = 100 // Maximum page size to prevent abuse
}
// Use paginated methods for filtered queries
switch filter {
case "completed":
result, err := h.assignmentService.GetCompletedByUserPaginated(userID, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
c.JSON(http.StatusOK, gin.H{
"assignments": result.Assignments,
"count": len(result.Assignments),
"total_count": result.TotalCount,
"total_pages": result.TotalPages,
"current_page": result.CurrentPage,
"page_size": result.PageSize,
})
return
case "overdue":
result, err := h.assignmentService.GetOverdueByUserPaginated(userID, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
c.JSON(http.StatusOK, gin.H{
"assignments": result.Assignments,
"count": len(result.Assignments),
"total_count": result.TotalCount,
"total_pages": result.TotalPages,
"current_page": result.CurrentPage,
"page_size": result.PageSize,
})
return
case "pending":
result, err := h.assignmentService.GetPendingByUserPaginated(userID, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
c.JSON(http.StatusOK, gin.H{
"assignments": result.Assignments,
"count": len(result.Assignments),
"total_count": result.TotalCount,
"total_pages": result.TotalPages,
"current_page": result.CurrentPage,
"page_size": result.PageSize,
})
return
default:
// For "all" filter, use simple pagination without a dedicated method
assignments, err := h.assignmentService.GetAllByUser(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
// Manual pagination for all assignments
totalCount := len(assignments)
totalPages := (totalCount + pageSize - 1) / pageSize
start := (page - 1) * pageSize
end := start + pageSize
if start > totalCount {
start = totalCount
}
if end > totalCount {
end = totalCount
}
c.JSON(http.StatusOK, gin.H{
"assignments": assignments[start:end],
"count": end - start,
"total_count": totalCount,
"total_pages": totalPages,
"current_page": page,
"page_size": pageSize,
})
}
}
// ListPendingAssignments returns pending assignments with pagination
// GET /api/v1/assignments/pending?page=1&page_size=20
func (h *APIHandler) ListPendingAssignments(c *gin.Context) {
userID := h.getUserID(c)
page, pageSize := h.parsePagination(c)
result, err := h.assignmentService.GetPendingByUserPaginated(userID, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
h.sendPaginatedResponse(c, result)
}
// ListCompletedAssignments returns completed assignments with pagination
// GET /api/v1/assignments/completed?page=1&page_size=20
func (h *APIHandler) ListCompletedAssignments(c *gin.Context) {
userID := h.getUserID(c)
page, pageSize := h.parsePagination(c)
result, err := h.assignmentService.GetCompletedByUserPaginated(userID, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
h.sendPaginatedResponse(c, result)
}
// ListOverdueAssignments returns overdue assignments with pagination
// GET /api/v1/assignments/overdue?page=1&page_size=20
func (h *APIHandler) ListOverdueAssignments(c *gin.Context) {
userID := h.getUserID(c)
page, pageSize := h.parsePagination(c)
result, err := h.assignmentService.GetOverdueByUserPaginated(userID, page, pageSize)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
h.sendPaginatedResponse(c, result)
}
// ListDueTodayAssignments returns assignments due today
// GET /api/v1/assignments/due-today
func (h *APIHandler) ListDueTodayAssignments(c *gin.Context) {
userID := h.getUserID(c)
assignments, err := h.assignmentService.GetDueTodayByUser(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
c.JSON(http.StatusOK, gin.H{
"assignments": assignments,
"count": len(assignments),
})
}
// ListDueThisWeekAssignments returns assignments due within this week
// GET /api/v1/assignments/due-this-week
func (h *APIHandler) ListDueThisWeekAssignments(c *gin.Context) {
userID := h.getUserID(c)
assignments, err := h.assignmentService.GetDueThisWeekByUser(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch assignments"})
return
}
c.JSON(http.StatusOK, gin.H{
"assignments": assignments,
"count": len(assignments),
})
}
// parsePagination extracts and validates pagination parameters
func (h *APIHandler) parsePagination(c *gin.Context) (page int, pageSize int) {
page, _ = strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ = strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 20
}
if pageSize > 100 {
pageSize = 100
}
return page, pageSize
}
// sendPaginatedResponse sends a standard paginated JSON response
func (h *APIHandler) sendPaginatedResponse(c *gin.Context, result *service.PaginatedResult) {
c.JSON(http.StatusOK, gin.H{
"assignments": result.Assignments,
"count": len(result.Assignments),
"total_count": result.TotalCount,
"total_pages": result.TotalPages,
"current_page": result.CurrentPage,
"page_size": result.PageSize,
})
}
// GetAssignment returns a single assignment by ID
// GET /api/v1/assignments/:id
func (h *APIHandler) GetAssignment(c *gin.Context) {
userID := h.getUserID(c)
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid assignment ID"})
return
}
assignment, err := h.assignmentService.GetByID(userID, uint(id))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Assignment not found"})
return
}
c.JSON(http.StatusOK, assignment)
}
// CreateAssignmentInput represents the JSON input for creating an assignment
type CreateAssignmentInput struct {
Title string `json:"title" binding:"required"`
Description string `json:"description"`
Subject string `json:"subject"`
Priority string `json:"priority"` // low, medium, high (default: medium)
DueDate string `json:"due_date" binding:"required"` // RFC3339 or 2006-01-02T15:04
}
// CreateAssignment creates a new assignment
// POST /api/v1/assignments
func (h *APIHandler) CreateAssignment(c *gin.Context) {
userID := h.getUserID(c)
var input CreateAssignmentInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input: title and due_date are required"})
return
}
dueDate, err := time.Parse(time.RFC3339, input.DueDate)
if err != nil {
dueDate, err = time.ParseInLocation("2006-01-02T15:04", input.DueDate, time.Local)
if err != nil {
dueDate, err = time.ParseInLocation("2006-01-02", input.DueDate, time.Local)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid due_date format. Use RFC3339 or 2006-01-02T15:04"})
return
}
dueDate = dueDate.Add(23*time.Hour + 59*time.Minute)
}
}
assignment, err := h.assignmentService.Create(userID, input.Title, input.Description, input.Subject, input.Priority, dueDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create assignment"})
return
}
c.JSON(http.StatusCreated, assignment)
}
// UpdateAssignmentInput represents the JSON input for updating an assignment
type UpdateAssignmentInput struct {
Title string `json:"title"`
Description string `json:"description"`
Subject string `json:"subject"`
Priority string `json:"priority"`
DueDate string `json:"due_date"`
}
// UpdateAssignment updates an existing assignment
// PUT /api/v1/assignments/:id
func (h *APIHandler) UpdateAssignment(c *gin.Context) {
userID := h.getUserID(c)
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid assignment ID"})
return
}
// Get existing assignment
existing, err := h.assignmentService.GetByID(userID, uint(id))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Assignment not found"})
return
}
var input UpdateAssignmentInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input"})
return
}
// Use existing values if not provided
title := input.Title
if title == "" {
title = existing.Title
}
description := input.Description
subject := input.Subject
priority := input.Priority
if priority == "" {
priority = existing.Priority
}
dueDate := existing.DueDate
if input.DueDate != "" {
dueDate, err = time.Parse(time.RFC3339, input.DueDate)
if err != nil {
dueDate, err = time.ParseInLocation("2006-01-02T15:04", input.DueDate, time.Local)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid due_date format"})
return
}
}
}
assignment, err := h.assignmentService.Update(userID, uint(id), title, description, subject, priority, dueDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update assignment"})
return
}
c.JSON(http.StatusOK, assignment)
}
// DeleteAssignment deletes an assignment
// DELETE /api/v1/assignments/:id
func (h *APIHandler) DeleteAssignment(c *gin.Context) {
userID := h.getUserID(c)
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid assignment ID"})
return
}
if err := h.assignmentService.Delete(userID, uint(id)); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Assignment not found"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Assignment deleted"})
}
// ToggleAssignment toggles the completion status of an assignment
// PATCH /api/v1/assignments/:id/toggle
func (h *APIHandler) ToggleAssignment(c *gin.Context) {
userID := h.getUserID(c)
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid assignment ID"})
return
}
assignment, err := h.assignmentService.ToggleComplete(userID, uint(id))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Assignment not found"})
return
}
c.JSON(http.StatusOK, assignment)
}

View File

@@ -0,0 +1,233 @@
package handler
import (
"net/http"
"strconv"
"strings"
"time"
"homework-manager/internal/middleware"
"homework-manager/internal/models"
"homework-manager/internal/service"
"github.com/gin-gonic/gin"
)
type AssignmentHandler struct {
assignmentService *service.AssignmentService
}
func NewAssignmentHandler() *AssignmentHandler {
return &AssignmentHandler{
assignmentService: service.NewAssignmentService(),
}
}
func (h *AssignmentHandler) getUserID(c *gin.Context) uint {
userID, _ := c.Get(middleware.UserIDKey)
return userID.(uint)
}
func (h *AssignmentHandler) Dashboard(c *gin.Context) {
userID := h.getUserID(c)
stats, _ := h.assignmentService.GetDashboardStats(userID)
dueToday, _ := h.assignmentService.GetDueTodayByUser(userID)
overdue, _ := h.assignmentService.GetOverdueByUser(userID)
upcoming, _ := h.assignmentService.GetDueThisWeekByUser(userID)
role, _ := c.Get(middleware.UserRoleKey)
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "dashboard.html", gin.H{
"title": "ダッシュボード",
"stats": stats,
"dueToday": dueToday,
"overdue": overdue,
"upcoming": upcoming,
"isAdmin": role == "admin",
"userName": name,
})
}
func (h *AssignmentHandler) Index(c *gin.Context) {
userID := h.getUserID(c)
filter := c.Query("filter")
filter = strings.TrimSpace(filter)
if filter == "" {
filter = "pending"
}
query := c.Query("q")
priority := c.Query("priority")
pageStr := c.DefaultQuery("page", "1")
page, err := strconv.Atoi(pageStr)
if err != nil || page < 1 {
page = 1
}
const pageSize = 10
result, err := h.assignmentService.SearchAssignments(userID, query, priority, filter, page, pageSize)
var assignments []models.Assignment
var totalPages, currentPage int
if err != nil || result == nil {
assignments = []models.Assignment{}
totalPages = 1
currentPage = 1
} else {
assignments = result.Assignments
totalPages = result.TotalPages
currentPage = result.CurrentPage
}
role, _ := c.Get(middleware.UserRoleKey)
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "assignments/index.html", gin.H{
"title": "課題一覧",
"assignments": assignments,
"filter": filter,
"query": query,
"priority": priority,
"isAdmin": role == "admin",
"userName": name,
"currentPage": currentPage,
"totalPages": totalPages,
"hasPrev": currentPage > 1,
"hasNext": currentPage < totalPages,
"prevPage": currentPage - 1,
"nextPage": currentPage + 1,
})
}
func (h *AssignmentHandler) New(c *gin.Context) {
role, _ := c.Get(middleware.UserRoleKey)
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "assignments/new.html", gin.H{
"title": "課題登録",
"isAdmin": role == "admin",
"userName": name,
})
}
func (h *AssignmentHandler) Create(c *gin.Context) {
userID := h.getUserID(c)
title := c.PostForm("title")
description := c.PostForm("description")
subject := c.PostForm("subject")
priority := c.PostForm("priority")
dueDateStr := c.PostForm("due_date")
dueDate, err := time.ParseInLocation("2006-01-02T15:04", dueDateStr, time.Local)
if err != nil {
dueDate, err = time.ParseInLocation("2006-01-02", dueDateStr, time.Local)
if err != nil {
role, _ := c.Get(middleware.UserRoleKey)
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "assignments/new.html", gin.H{
"title": "課題登録",
"error": "提出期限の形式が正しくありません",
"formTitle": title,
"description": description,
"subject": subject,
"priority": priority,
"isAdmin": role == "admin",
"userName": name,
})
return
}
dueDate = dueDate.Add(23*time.Hour + 59*time.Minute)
}
_, err = h.assignmentService.Create(userID, title, description, subject, priority, dueDate)
if err != nil {
role, _ := c.Get(middleware.UserRoleKey)
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "assignments/new.html", gin.H{
"title": "課題登録",
"error": "課題の登録に失敗しました",
"formTitle": title,
"description": description,
"subject": subject,
"priority": priority,
"isAdmin": role == "admin",
"userName": name,
})
return
}
c.Redirect(http.StatusFound, "/assignments")
}
func (h *AssignmentHandler) Edit(c *gin.Context) {
userID := h.getUserID(c)
id, _ := strconv.ParseUint(c.Param("id"), 10, 32)
assignment, err := h.assignmentService.GetByID(userID, uint(id))
if err != nil {
c.Redirect(http.StatusFound, "/assignments")
return
}
role, _ := c.Get(middleware.UserRoleKey)
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "assignments/edit.html", gin.H{
"title": "課題編集",
"assignment": assignment,
"isAdmin": role == "admin",
"userName": name,
})
}
func (h *AssignmentHandler) Update(c *gin.Context) {
userID := h.getUserID(c)
id, _ := strconv.ParseUint(c.Param("id"), 10, 32)
title := c.PostForm("title")
description := c.PostForm("description")
subject := c.PostForm("subject")
priority := c.PostForm("priority")
dueDateStr := c.PostForm("due_date")
dueDate, err := time.ParseInLocation("2006-01-02T15:04", dueDateStr, time.Local)
if err != nil {
dueDate, err = time.ParseInLocation("2006-01-02", dueDateStr, time.Local)
if err != nil {
c.Redirect(http.StatusFound, "/assignments")
return
}
dueDate = dueDate.Add(23*time.Hour + 59*time.Minute)
}
_, err = h.assignmentService.Update(userID, uint(id), title, description, subject, priority, dueDate)
if err != nil {
c.Redirect(http.StatusFound, "/assignments")
return
}
c.Redirect(http.StatusFound, "/assignments")
}
func (h *AssignmentHandler) Toggle(c *gin.Context) {
userID := h.getUserID(c)
id, _ := strconv.ParseUint(c.Param("id"), 10, 32)
h.assignmentService.ToggleComplete(userID, uint(id))
referer := c.Request.Referer()
if referer == "" {
referer = "/assignments"
}
c.Redirect(http.StatusFound, referer)
}
func (h *AssignmentHandler) Delete(c *gin.Context) {
userID := h.getUserID(c)
id, _ := strconv.ParseUint(c.Param("id"), 10, 32)
h.assignmentService.Delete(userID, uint(id))
c.Redirect(http.StatusFound, "/assignments")
}

View File

@@ -0,0 +1,114 @@
package handler
import (
"net/http"
"homework-manager/internal/middleware"
"homework-manager/internal/service"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type AuthHandler struct {
authService *service.AuthService
}
func NewAuthHandler() *AuthHandler {
return &AuthHandler{
authService: service.NewAuthService(),
}
}
func (h *AuthHandler) ShowLogin(c *gin.Context) {
RenderHTML(c, http.StatusOK, "login.html", gin.H{
"title": "ログイン",
})
}
func (h *AuthHandler) Login(c *gin.Context) {
email := c.PostForm("email")
password := c.PostForm("password")
user, err := h.authService.Login(email, password)
if err != nil {
RenderHTML(c, http.StatusOK, "login.html", gin.H{
"title": "ログイン",
"error": "メールアドレスまたはパスワードが正しくありません",
"email": email,
})
return
}
session := sessions.Default(c)
session.Set(middleware.UserIDKey, user.ID)
session.Set(middleware.UserRoleKey, user.Role)
session.Set(middleware.UserNameKey, user.Name)
session.Save()
c.Redirect(http.StatusFound, "/")
}
func (h *AuthHandler) ShowRegister(c *gin.Context) {
RenderHTML(c, http.StatusOK, "register.html", gin.H{
"title": "新規登録",
})
}
func (h *AuthHandler) Register(c *gin.Context) {
email := c.PostForm("email")
password := c.PostForm("password")
passwordConfirm := c.PostForm("password_confirm")
name := c.PostForm("name")
if password != passwordConfirm {
RenderHTML(c, http.StatusOK, "register.html", gin.H{
"title": "新規登録",
"error": "パスワードが一致しません",
"email": email,
"name": name,
})
return
}
if len(password) < 8 {
RenderHTML(c, http.StatusOK, "register.html", gin.H{
"title": "新規登録",
"error": "パスワードは8文字以上で入力してください",
"email": email,
"name": name,
})
return
}
user, err := h.authService.Register(email, password, name)
if err != nil {
errorMsg := "登録に失敗しました"
if err == service.ErrEmailAlreadyExists {
errorMsg = "このメールアドレスは既に使用されています"
}
RenderHTML(c, http.StatusOK, "register.html", gin.H{
"title": "新規登録",
"error": errorMsg,
"email": email,
"name": name,
})
return
}
session := sessions.Default(c)
session.Set(middleware.UserIDKey, user.ID)
session.Set(middleware.UserRoleKey, user.Role)
session.Set(middleware.UserNameKey, user.Name)
session.Save()
c.Redirect(http.StatusFound, "/")
}
func (h *AuthHandler) Logout(c *gin.Context) {
session := sessions.Default(c)
session.Clear()
session.Save()
c.Redirect(http.StatusFound, "/login")
}

View File

@@ -0,0 +1,33 @@
package handler
import (
"fmt"
"html/template"
"time"
"github.com/gin-gonic/gin"
)
const csrfTokenKey = "csrf_token"
const csrfTokenFormKey = "_csrf"
func RenderHTML(c *gin.Context, code int, name string, obj gin.H) {
if obj == nil {
obj = gin.H{}
}
if startTime, exists := c.Get("startTime"); exists {
duration := time.Since(startTime.(time.Time))
obj["processing_time"] = fmt.Sprintf("%.2fms", float64(duration.Microseconds())/1000.0)
} else {
obj["processing_time"] = "unknown"
}
if token, exists := c.Get(csrfTokenKey); exists {
obj["csrfToken"] = token.(string)
obj["csrfField"] = template.HTML(`<input type="hidden" name="` + csrfTokenFormKey + `" value="` + token.(string) + `">`)
}
c.HTML(code, name, obj)
}

View File

@@ -0,0 +1,122 @@
package handler
import (
"net/http"
"homework-manager/internal/middleware"
"homework-manager/internal/service"
"github.com/gin-gonic/gin"
)
type ProfileHandler struct {
authService *service.AuthService
}
func NewProfileHandler() *ProfileHandler {
return &ProfileHandler{
authService: service.NewAuthService(),
}
}
func (h *ProfileHandler) getUserID(c *gin.Context) uint {
userID, _ := c.Get(middleware.UserIDKey)
return userID.(uint)
}
func (h *ProfileHandler) Show(c *gin.Context) {
userID := h.getUserID(c)
user, _ := h.authService.GetUserByID(userID)
role, _ := c.Get(middleware.UserRoleKey)
name, _ := c.Get(middleware.UserNameKey)
RenderHTML(c, http.StatusOK, "profile.html", gin.H{
"title": "プロフィール",
"user": user,
"isAdmin": role == "admin",
"userName": name,
})
}
func (h *ProfileHandler) Update(c *gin.Context) {
userID := h.getUserID(c)
name := c.PostForm("name")
err := h.authService.UpdateProfile(userID, name)
role, _ := c.Get(middleware.UserRoleKey)
user, _ := h.authService.GetUserByID(userID)
if err != nil {
RenderHTML(c, http.StatusOK, "profile.html", gin.H{
"title": "プロフィール",
"user": user,
"error": "プロフィールの更新に失敗しました",
"isAdmin": role == "admin",
"userName": name,
})
return
}
RenderHTML(c, http.StatusOK, "profile.html", gin.H{
"title": "プロフィール",
"user": user,
"success": "プロフィールを更新しました",
"isAdmin": role == "admin",
"userName": user.Name,
})
}
func (h *ProfileHandler) ChangePassword(c *gin.Context) {
userID := h.getUserID(c)
oldPassword := c.PostForm("old_password")
newPassword := c.PostForm("new_password")
confirmPassword := c.PostForm("confirm_password")
role, _ := c.Get(middleware.UserRoleKey)
name, _ := c.Get(middleware.UserNameKey)
user, _ := h.authService.GetUserByID(userID)
if newPassword != confirmPassword {
RenderHTML(c, http.StatusOK, "profile.html", gin.H{
"title": "プロフィール",
"user": user,
"passwordError": "新しいパスワードが一致しません",
"isAdmin": role == "admin",
"userName": name,
})
return
}
if len(newPassword) < 8 {
RenderHTML(c, http.StatusOK, "profile.html", gin.H{
"title": "プロフィール",
"user": user,
"passwordError": "パスワードは8文字以上で入力してください",
"isAdmin": role == "admin",
"userName": name,
})
return
}
err := h.authService.ChangePassword(userID, oldPassword, newPassword)
if err != nil {
RenderHTML(c, http.StatusOK, "profile.html", gin.H{
"title": "プロフィール",
"user": user,
"passwordError": "現在のパスワードが正しくありません",
"isAdmin": role == "admin",
"userName": name,
})
return
}
RenderHTML(c, http.StatusOK, "profile.html", gin.H{
"title": "プロフィール",
"user": user,
"passwordSuccess": "パスワードを変更しました",
"isAdmin": role == "admin",
"userName": name,
})
}

121
internal/middleware/auth.go Normal file
View File

@@ -0,0 +1,121 @@
package middleware
import (
"net/http"
"homework-manager/internal/service"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
const UserIDKey = "user_id"
const UserRoleKey = "user_role"
const UserNameKey = "user_name"
func AuthRequired(authService *service.AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
userID := session.Get(UserIDKey)
if userID == nil {
c.Redirect(http.StatusFound, "/login")
c.Abort()
return
}
user, err := authService.GetUserByID(userID.(uint))
if err != nil {
session.Clear()
session.Save()
c.Redirect(http.StatusFound, "/login")
c.Abort()
return
}
c.Set(UserIDKey, user.ID)
c.Set(UserRoleKey, user.Role)
c.Set(UserNameKey, user.Name)
c.Next()
}
}
func AdminRequired() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get(UserRoleKey)
if !exists || role != "admin" {
c.HTML(http.StatusForbidden, "error.html", gin.H{
"title": "アクセス拒否",
"message": "この操作には管理者権限が必要です。",
})
c.Abort()
return
}
c.Next()
}
}
func GuestOnly() gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
userID := session.Get(UserIDKey)
if userID != nil {
c.Redirect(http.StatusFound, "/")
c.Abort()
return
}
c.Next()
}
}
func InjectUserInfo() gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
userID := session.Get(UserIDKey)
if userID != nil {
c.Set(UserIDKey, userID.(uint))
c.Set(UserRoleKey, session.Get(UserRoleKey))
c.Set(UserNameKey, session.Get(UserNameKey))
}
c.Next()
}
}
type APIKeyValidator interface {
ValidateAPIKey(key string) (uint, error)
}
func APIKeyAuth(validator APIKeyValidator) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
c.Abort()
return
}
const bearerPrefix = "Bearer "
if len(authHeader) <= len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid authorization format. Use: Bearer <api_key>"})
c.Abort()
return
}
apiKey := authHeader[len(bearerPrefix):]
userID, err := validator.ValidateAPIKey(apiKey)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
c.Abort()
return
}
c.Set(UserIDKey, userID)
c.Next()
}
}

119
internal/middleware/csrf.go Normal file
View File

@@ -0,0 +1,119 @@
package middleware
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"html/template"
"net/http"
"strings"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
const (
csrfTokenKey = "csrf_token"
csrfTokenFormKey = "_csrf"
csrfTokenHeader = "X-CSRF-Token"
)
type CSRFConfig struct {
Secret string
}
func generateCSRFToken(secret string) (string, error) {
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return "", err
}
h := hmac.New(sha256.New, []byte(secret))
h.Write(randomBytes)
signature := h.Sum(nil)
token := append(randomBytes, signature...)
return base64.URLEncoding.EncodeToString(token), nil
}
func validateCSRFToken(token, secret string) bool {
if token == "" {
return false
}
decoded, err := base64.URLEncoding.DecodeString(token)
if err != nil {
return false
}
if len(decoded) != 64 {
return false
}
randomBytes := decoded[:32]
providedSignature := decoded[32:]
h := hmac.New(sha256.New, []byte(secret))
h.Write(randomBytes)
expectedSignature := h.Sum(nil)
return hmac.Equal(providedSignature, expectedSignature)
}
func CSRF(config CSRFConfig) gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
csrfToken, ok := session.Get(csrfTokenKey).(string)
if !ok || csrfToken == "" || !validateCSRFToken(csrfToken, config.Secret) {
newToken, err := generateCSRFToken(config.Secret)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
csrfToken = newToken
session.Set(csrfTokenKey, csrfToken)
session.Save()
}
c.Set(csrfTokenKey, csrfToken)
method := strings.ToUpper(c.Request.Method)
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
c.Next()
return
}
submittedToken := c.PostForm(csrfTokenFormKey)
if submittedToken == "" {
submittedToken = c.GetHeader(csrfTokenHeader)
}
sessionToken := session.Get(csrfTokenKey)
if sessionToken == nil || submittedToken != sessionToken.(string) {
c.HTML(http.StatusForbidden, "error.html", gin.H{
"title": "CSRFエラー",
"message": "不正なリクエストです。ページを再読み込みしてください。",
})
c.Abort()
return
}
c.Next()
newToken, err := generateCSRFToken(config.Secret)
if err == nil {
session.Set(csrfTokenKey, newToken)
session.Save()
}
}
}
func CSRFField(c *gin.Context) template.HTML {
token, exists := c.Get(csrfTokenKey)
if !exists {
return ""
}
return template.HTML(`<input type="hidden" name="` + csrfTokenFormKey + `" value="` + token.(string) + `">`)
}

View File

@@ -0,0 +1,94 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
type RateLimitConfig struct {
Enabled bool
Requests int
Window int
}
type rateLimitEntry struct {
count int
expiresAt time.Time
}
type rateLimiter struct {
entries map[string]*rateLimitEntry
mu sync.Mutex
config RateLimitConfig
}
func newRateLimiter(config RateLimitConfig) *rateLimiter {
rl := &rateLimiter{
entries: make(map[string]*rateLimitEntry),
config: config,
}
go rl.cleanup()
return rl
}
func (rl *rateLimiter) cleanup() {
ticker := time.NewTicker(time.Minute)
for range ticker.C {
rl.mu.Lock()
now := time.Now()
for key, entry := range rl.entries {
if now.After(entry.expiresAt) {
delete(rl.entries, key)
}
}
rl.mu.Unlock()
}
}
func (rl *rateLimiter) allow(key string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
entry, exists := rl.entries[key]
if !exists || now.After(entry.expiresAt) {
rl.entries[key] = &rateLimitEntry{
count: 1,
expiresAt: now.Add(time.Duration(rl.config.Window) * time.Second),
}
return true
}
entry.count++
return entry.count <= rl.config.Requests
}
func RateLimit(config RateLimitConfig) gin.HandlerFunc {
if !config.Enabled {
return func(c *gin.Context) {
c.Next()
}
}
limiter := newRateLimiter(config)
return func(c *gin.Context) {
clientIP := c.ClientIP()
if !limiter.allow(clientIP) {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "リクエスト数が制限を超えました。しばらくしてからお試しください。",
})
c.Abort()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,52 @@
package middleware
import (
"strings"
"github.com/gin-gonic/gin"
)
type SecurityConfig struct {
HTTPS bool
}
func SecurityHeaders(config SecurityConfig) gin.HandlerFunc {
return func(c *gin.Context) {
if config.HTTPS {
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
}
csp := []string{
"default-src 'self'",
"script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net",
"style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net",
"font-src 'self' https://cdn.jsdelivr.net",
"img-src 'self' data:",
"connect-src 'self'",
"frame-ancestors 'none'",
}
c.Header("Content-Security-Policy", strings.Join(csp, "; "))
c.Header("X-Frame-Options", "DENY")
c.Header("X-Content-Type-Options", "nosniff")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("X-XSS-Protection", "1; mode=block")
c.Next()
}
}
func ForceHTTPS(config SecurityConfig) gin.HandlerFunc {
return func(c *gin.Context) {
if config.HTTPS && c.Request.TLS == nil && c.Request.Header.Get("X-Forwarded-Proto") != "https" {
host := c.Request.Host
target := "https://" + host + c.Request.URL.Path
if len(c.Request.URL.RawQuery) > 0 {
target += "?" + c.Request.URL.RawQuery
}
c.Redirect(301, target)
c.Abort()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,14 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
)
func RequestTimer() gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("startTime", time.Now())
c.Next()
}
}

View File

@@ -0,0 +1,19 @@
package models
import (
"time"
"gorm.io/gorm"
)
type APIKey struct {
ID uint `gorm:"primarykey" json:"id"`
UserID uint `gorm:"not null;index" json:"user_id"`
Name string `gorm:"not null" json:"name"`
KeyHash string `gorm:"not null;uniqueIndex" json:"-"`
LastUsed *time.Time `json:"last_used,omitempty"`
CreatedAt time.Time `json:"created_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
}

View File

@@ -0,0 +1,41 @@
package models
import (
"time"
"gorm.io/gorm"
)
type Assignment struct {
ID uint `gorm:"primarykey" json:"id"`
UserID uint `gorm:"not null;index" json:"user_id"`
Title string `gorm:"not null" json:"title"`
Description string `json:"description"`
Subject string `json:"subject"`
Priority string `gorm:"not null;default:medium" json:"priority"` // low, medium, high
DueDate time.Time `gorm:"not null" json:"due_date"`
IsCompleted bool `gorm:"default:false" json:"is_completed"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
}
func (a *Assignment) IsOverdue() bool {
return !a.IsCompleted && time.Now().After(a.DueDate)
}
func (a *Assignment) IsDueToday() bool {
now := time.Now()
return a.DueDate.Year() == now.Year() &&
a.DueDate.Month() == now.Month() &&
a.DueDate.Day() == now.Day()
}
func (a *Assignment) IsDueThisWeek() bool {
now := time.Now()
weekLater := now.AddDate(0, 0, 7)
return a.DueDate.After(now) && a.DueDate.Before(weekLater)
}

28
internal/models/user.go Normal file
View File

@@ -0,0 +1,28 @@
package models
import (
"time"
"gorm.io/gorm"
)
type User struct {
ID uint `gorm:"primarykey" json:"id"`
Email string `gorm:"uniqueIndex;not null" json:"email"`
PasswordHash string `gorm:"not null" json:"-"`
Name string `gorm:"not null" json:"name"`
Role string `gorm:"not null;default:user" json:"role"` // "admin" or "user"
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
Assignments []Assignment `gorm:"foreignKey:UserID" json:"assignments,omitempty"`
}
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
func (u *User) GetID() uint {
return u.ID
}

View File

@@ -0,0 +1,188 @@
package repository
import (
"time"
"homework-manager/internal/database"
"homework-manager/internal/models"
"gorm.io/gorm"
)
type AssignmentRepository struct {
db *gorm.DB
}
func NewAssignmentRepository() *AssignmentRepository {
return &AssignmentRepository{db: database.GetDB()}
}
func (r *AssignmentRepository) Create(assignment *models.Assignment) error {
return r.db.Create(assignment).Error
}
func (r *AssignmentRepository) FindByID(id uint) (*models.Assignment, error) {
var assignment models.Assignment
err := r.db.First(&assignment, id).Error
if err != nil {
return nil, err
}
return &assignment, nil
}
func (r *AssignmentRepository) FindByUserID(userID uint) ([]models.Assignment, error) {
var assignments []models.Assignment
err := r.db.Where("user_id = ?", userID).Order("due_date ASC").Find(&assignments).Error
return assignments, err
}
func (r *AssignmentRepository) FindPendingByUserID(userID uint, limit, offset int) ([]models.Assignment, error) {
var assignments []models.Assignment
query := r.db.Where("user_id = ? AND is_completed = ?", userID, false).
Order("due_date ASC")
if limit > 0 {
query = query.Limit(limit).Offset(offset)
}
err := query.Find(&assignments).Error
return assignments, err
}
func (r *AssignmentRepository) FindCompletedByUserID(userID uint, limit, offset int) ([]models.Assignment, error) {
var assignments []models.Assignment
query := r.db.Where("user_id = ? AND is_completed = ?", userID, true).
Order("completed_at DESC")
if limit > 0 {
query = query.Limit(limit).Offset(offset)
}
err := query.Find(&assignments).Error
return assignments, err
}
func (r *AssignmentRepository) FindDueTodayByUserID(userID uint) ([]models.Assignment, error) {
now := time.Now()
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
endOfDay := startOfDay.AddDate(0, 0, 1)
var assignments []models.Assignment
err := r.db.Where("user_id = ? AND is_completed = ? AND due_date >= ? AND due_date < ?",
userID, false, startOfDay, endOfDay).
Order("due_date ASC").Find(&assignments).Error
return assignments, err
}
func (r *AssignmentRepository) FindDueThisWeekByUserID(userID uint) ([]models.Assignment, error) {
now := time.Now()
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
weekLater := startOfDay.AddDate(0, 0, 7)
var assignments []models.Assignment
err := r.db.Where("user_id = ? AND is_completed = ? AND due_date >= ? AND due_date < ?",
userID, false, startOfDay, weekLater).
Order("due_date ASC").Find(&assignments).Error
return assignments, err
}
func (r *AssignmentRepository) FindOverdueByUserID(userID uint, limit, offset int) ([]models.Assignment, error) {
now := time.Now()
var assignments []models.Assignment
query := r.db.Where("user_id = ? AND is_completed = ? AND due_date < ?",
userID, false, now).
Order("due_date ASC")
if limit > 0 {
query = query.Limit(limit).Offset(offset)
}
err := query.Find(&assignments).Error
return assignments, err
}
func (r *AssignmentRepository) Update(assignment *models.Assignment) error {
return r.db.Save(assignment).Error
}
func (r *AssignmentRepository) Delete(id uint) error {
return r.db.Delete(&models.Assignment{}, id).Error
}
func (r *AssignmentRepository) CountByUserID(userID uint) (int64, error) {
var count int64
err := r.db.Model(&models.Assignment{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
func (r *AssignmentRepository) CountPendingByUserID(userID uint) (int64, error) {
var count int64
err := r.db.Model(&models.Assignment{}).
Where("user_id = ? AND is_completed = ?", userID, false).Count(&count).Error
return count, err
}
func (r *AssignmentRepository) GetSubjectsByUserID(userID uint) ([]string, error) {
var subjects []string
err := r.db.Model(&models.Assignment{}).
Where("user_id = ? AND subject != ''", userID).
Distinct("subject").
Pluck("subject", &subjects).Error
return subjects, err
}
func (r *AssignmentRepository) CountCompletedByUserID(userID uint) (int64, error) {
var count int64
err := r.db.Model(&models.Assignment{}).
Where("user_id = ? AND is_completed = ?", userID, true).Count(&count).Error
return count, err
}
func (r *AssignmentRepository) Search(userID uint, queryStr, priority, filter string, page, pageSize int) ([]models.Assignment, int64, error) {
var assignments []models.Assignment
var totalCount int64
dbQuery := r.db.Model(&models.Assignment{}).Where("user_id = ?", userID)
if queryStr != "" {
dbQuery = dbQuery.Where("title LIKE ? OR description LIKE ?", "%"+queryStr+"%", "%"+queryStr+"%")
}
if priority != "" {
dbQuery = dbQuery.Where("priority = ?", priority)
}
now := time.Now()
switch filter {
case "completed":
dbQuery = dbQuery.Where("is_completed = ?", true)
case "overdue":
dbQuery = dbQuery.Where("is_completed = ? AND due_date < ?", false, now)
default: // pending
dbQuery = dbQuery.Where("is_completed = ?", false)
}
if err := dbQuery.Count(&totalCount).Error; err != nil {
return nil, 0, err
}
if filter == "completed" {
dbQuery = dbQuery.Order("completed_at DESC")
} else {
dbQuery = dbQuery.Order("due_date ASC")
}
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
offset := (page - 1) * pageSize
err := dbQuery.Limit(pageSize).Offset(offset).Find(&assignments).Error
return assignments, totalCount, err
}
func (r *AssignmentRepository) CountOverdueByUserID(userID uint) (int64, error) {
var count int64
now := time.Now()
err := r.db.Model(&models.Assignment{}).
Where("user_id = ? AND is_completed = ? AND due_date < ?", userID, false, now).Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,61 @@
package repository
import (
"homework-manager/internal/database"
"homework-manager/internal/models"
"gorm.io/gorm"
)
type UserRepository struct {
db *gorm.DB
}
func NewUserRepository() *UserRepository {
return &UserRepository{db: database.GetDB()}
}
func (r *UserRepository) Create(user *models.User) error {
return r.db.Create(user).Error
}
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
var user models.User
err := r.db.First(&user, id).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *UserRepository) FindByEmail(email string) (*models.User, error) {
var user models.User
err := r.db.Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *UserRepository) FindAll() ([]models.User, error) {
var users []models.User
err := r.db.Find(&users).Error
return users, err
}
func (r *UserRepository) Update(user *models.User) error {
return r.db.Save(user).Error
}
func (r *UserRepository) Delete(id uint) error {
if err := r.db.Unscoped().Where("user_id = ?", id).Delete(&models.Assignment{}).Error; err != nil {
return err
}
return r.db.Unscoped().Delete(&models.User{}, id).Error
}
func (r *UserRepository) Count() (int64, error) {
var count int64
err := r.db.Model(&models.User{}).Count(&count).Error
return count, err
}

241
internal/router/router.go Normal file
View File

@@ -0,0 +1,241 @@
package router
import (
"html/template"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"homework-manager/internal/config"
"homework-manager/internal/handler"
"homework-manager/internal/middleware"
"homework-manager/internal/service"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
)
func getFuncMap() template.FuncMap {
return template.FuncMap{
"formatDate": func(t time.Time) string {
return t.Format("2006/01/02")
},
"formatDateTime": func(t time.Time) string {
return t.Format("2006/01/02 15:04")
},
"formatDateInput": func(t time.Time) string {
return t.Format("2006-01-02T15:04")
},
"isOverdue": func(t time.Time, completed bool) bool {
return !completed && time.Now().After(t)
},
"daysUntil": func(t time.Time) int {
return int(time.Until(t).Hours() / 24)
},
}
}
func loadTemplates() (*template.Template, error) {
tmpl := template.New("").Funcs(getFuncMap())
baseContent, err := os.ReadFile("web/templates/layouts/base.html")
if err != nil {
return nil, err
}
templateDirs := []struct {
pattern string
prefix string
}{
{"web/templates/auth/*.html", ""},
{"web/templates/pages/*.html", ""},
{"web/templates/assignments/*.html", "assignments/"},
{"web/templates/admin/*.html", "admin/"},
}
for _, dir := range templateDirs {
files, err := filepath.Glob(dir.pattern)
if err != nil {
return nil, err
}
for _, file := range files {
name := dir.prefix + filepath.Base(file)
content, err := os.ReadFile(file)
if err != nil {
return nil, err
}
reDefine := regexp.MustCompile(`{{\s*define\s+"([^"]+)"\s*}}`)
reTemplate := regexp.MustCompile(`{{\s*template\s+"([^"]+)"\s*([^}]*)\s*}}`)
uniqueBase := reDefine.ReplaceAllStringFunc(string(baseContent), func(m string) string {
match := reDefine.FindStringSubmatch(m)
blockName := match[1]
if blockName == "head" || blockName == "scripts" || blockName == "content" || blockName == "base" {
return strings.Replace(m, blockName, name+"_"+blockName, 1)
}
return m
})
uniqueBase = reTemplate.ReplaceAllStringFunc(uniqueBase, func(m string) string {
match := reTemplate.FindStringSubmatch(m)
blockName := match[1]
if blockName == "head" || blockName == "scripts" || blockName == "content" || blockName == "base" {
return strings.Replace(m, blockName, name+"_"+blockName, 1)
}
return m
})
uniqueContent := reDefine.ReplaceAllStringFunc(string(content), func(m string) string {
match := reDefine.FindStringSubmatch(m)
blockName := match[1]
if blockName == "head" || blockName == "scripts" || blockName == "content" {
return strings.Replace(m, blockName, name+"_"+blockName, 1)
}
return m
})
uniqueContent = reTemplate.ReplaceAllStringFunc(uniqueContent, func(m string) string {
match := reTemplate.FindStringSubmatch(m)
blockName := match[1]
if blockName == "base" {
return strings.Replace(m, blockName, name+"_"+blockName, 1)
}
return m
})
combined := uniqueBase + "\n" + uniqueContent
_, err = tmpl.New(name).Parse(combined)
if err != nil {
return nil, err
}
}
}
return tmpl, nil
}
func Setup(cfg *config.Config) *gin.Engine {
if !cfg.Debug {
gin.SetMode(gin.ReleaseMode)
}
r := gin.Default()
if len(cfg.TrustedProxies) > 0 {
r.SetTrustedProxies(cfg.TrustedProxies)
}
tmpl, err := loadTemplates()
if err != nil {
panic("Failed to load templates: " + err.Error())
}
r.SetHTMLTemplate(tmpl)
r.Static("/static", "web/static")
store := cookie.NewStore([]byte(cfg.SessionSecret))
store.Options(sessions.Options{
Path: "/",
MaxAge: 86400 * 7, // 7 days
HttpOnly: true,
Secure: cfg.HTTPS,
SameSite: http.SameSiteLaxMode,
})
r.Use(sessions.Sessions("session", store))
r.Use(middleware.RequestTimer())
securityConfig := middleware.SecurityConfig{
HTTPS: cfg.HTTPS,
}
r.Use(middleware.SecurityHeaders(securityConfig))
r.Use(middleware.ForceHTTPS(securityConfig))
r.Use(middleware.RateLimit(middleware.RateLimitConfig{
Enabled: cfg.RateLimitEnabled,
Requests: cfg.RateLimitRequests,
Window: cfg.RateLimitWindow,
}))
csrfMiddleware := middleware.CSRF(middleware.CSRFConfig{
Secret: cfg.CSRFSecret,
})
authService := service.NewAuthService()
apiKeyService := service.NewAPIKeyService()
authHandler := handler.NewAuthHandler()
assignmentHandler := handler.NewAssignmentHandler()
adminHandler := handler.NewAdminHandler()
profileHandler := handler.NewProfileHandler()
apiHandler := handler.NewAPIHandler()
guest := r.Group("/")
guest.Use(middleware.GuestOnly())
guest.Use(csrfMiddleware)
{
guest.GET("/login", authHandler.ShowLogin)
guest.POST("/login", authHandler.Login)
if cfg.AllowRegistration {
guest.GET("/register", authHandler.ShowRegister)
guest.POST("/register", authHandler.Register)
} else {
guest.GET("/register", func(c *gin.Context) {
c.HTML(http.StatusForbidden, "error.html", gin.H{
"title": "登録無効",
"message": "新規登録は現在受け付けておりません。",
})
})
}
}
auth := r.Group("/")
auth.Use(middleware.AuthRequired(authService))
auth.Use(csrfMiddleware)
{
auth.GET("/", assignmentHandler.Dashboard)
auth.POST("/logout", authHandler.Logout)
auth.GET("/assignments", assignmentHandler.Index)
auth.GET("/assignments/new", assignmentHandler.New)
auth.POST("/assignments", assignmentHandler.Create)
auth.GET("/assignments/:id/edit", assignmentHandler.Edit)
auth.POST("/assignments/:id", assignmentHandler.Update)
auth.POST("/assignments/:id/toggle", assignmentHandler.Toggle)
auth.POST("/assignments/:id/delete", assignmentHandler.Delete)
auth.GET("/profile", profileHandler.Show)
auth.POST("/profile", profileHandler.Update)
auth.POST("/profile/password", profileHandler.ChangePassword)
admin := auth.Group("/admin")
admin.Use(middleware.AdminRequired())
{
admin.GET("/users", adminHandler.Index)
admin.POST("/users/:id/delete", adminHandler.DeleteUser)
admin.POST("/users/:id/role", adminHandler.ChangeRole)
admin.GET("/api-keys", adminHandler.APIKeys)
admin.POST("/api-keys", adminHandler.CreateAPIKey)
admin.POST("/api-keys/:id/delete", adminHandler.DeleteAPIKey)
}
}
api := r.Group("/api/v1")
api.Use(middleware.APIKeyAuth(apiKeyService))
{
api.GET("/assignments", apiHandler.ListAssignments)
api.GET("/assignments/pending", apiHandler.ListPendingAssignments)
api.GET("/assignments/completed", apiHandler.ListCompletedAssignments)
api.GET("/assignments/overdue", apiHandler.ListOverdueAssignments)
api.GET("/assignments/due-today", apiHandler.ListDueTodayAssignments)
api.GET("/assignments/due-this-week", apiHandler.ListDueThisWeekAssignments)
api.GET("/assignments/:id", apiHandler.GetAssignment)
api.POST("/assignments", apiHandler.CreateAssignment)
api.PUT("/assignments/:id", apiHandler.UpdateAssignment)
api.DELETE("/assignments/:id", apiHandler.DeleteAssignment)
api.PATCH("/assignments/:id/toggle", apiHandler.ToggleAssignment)
}
return r
}

View File

@@ -0,0 +1,62 @@
package service
import (
"errors"
"homework-manager/internal/models"
"homework-manager/internal/repository"
)
var (
ErrCannotDeleteSelf = errors.New("cannot delete yourself")
ErrCannotChangeSelfRole = errors.New("cannot change your own role")
)
type AdminService struct {
userRepo *repository.UserRepository
}
func NewAdminService() *AdminService {
return &AdminService{
userRepo: repository.NewUserRepository(),
}
}
func (s *AdminService) GetAllUsers() ([]models.User, error) {
return s.userRepo.FindAll()
}
func (s *AdminService) GetUserByID(id uint) (*models.User, error) {
return s.userRepo.FindByID(id)
}
func (s *AdminService) DeleteUser(adminID, targetID uint) error {
if adminID == targetID {
return ErrCannotDeleteSelf
}
_, err := s.userRepo.FindByID(targetID)
if err != nil {
return ErrUserNotFound
}
return s.userRepo.Delete(targetID)
}
func (s *AdminService) ChangeRole(adminID, targetID uint, newRole string) error {
if adminID == targetID {
return ErrCannotChangeSelfRole
}
if newRole != "admin" && newRole != "user" {
return errors.New("invalid role")
}
user, err := s.userRepo.FindByID(targetID)
if err != nil {
return ErrUserNotFound
}
user.Role = newRole
return s.userRepo.Update(user)
}

View File

@@ -0,0 +1,89 @@
package service
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"time"
"homework-manager/internal/database"
"homework-manager/internal/models"
)
type APIKeyService struct{}
func NewAPIKeyService() *APIKeyService {
return &APIKeyService{}
}
func (s *APIKeyService) generateRandomKey() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return "hm_" + hex.EncodeToString(bytes), nil
}
func (s *APIKeyService) hashKey(key string) string {
hash := sha256.Sum256([]byte(key))
return hex.EncodeToString(hash[:])
}
func (s *APIKeyService) CreateAPIKey(userID uint, name string) (string, *models.APIKey, error) {
if name == "" {
return "", nil, errors.New("キー名を入力してください")
}
plainKey, err := s.generateRandomKey()
if err != nil {
return "", nil, errors.New("キーの生成に失敗しました")
}
apiKey := &models.APIKey{
UserID: userID,
Name: name,
KeyHash: s.hashKey(plainKey),
}
if err := database.GetDB().Create(apiKey).Error; err != nil {
return "", nil, errors.New("キーの保存に失敗しました")
}
return plainKey, apiKey, nil
}
func (s *APIKeyService) ValidateAPIKey(plainKey string) (uint, error) {
hash := s.hashKey(plainKey)
var apiKey models.APIKey
if err := database.GetDB().Where("key_hash = ?", hash).First(&apiKey).Error; err != nil {
return 0, errors.New("無効なAPIキーです")
}
now := time.Now()
database.GetDB().Model(&apiKey).Update("last_used", now)
return apiKey.UserID, nil
}
func (s *APIKeyService) GetAllAPIKeys() ([]models.APIKey, error) {
var keys []models.APIKey
err := database.GetDB().Preload("User").Order("created_at desc").Find(&keys).Error
return keys, err
}
func (s *APIKeyService) GetAPIKeysByUser(userID uint) ([]models.APIKey, error) {
var keys []models.APIKey
err := database.GetDB().Where("user_id = ?", userID).Order("created_at desc").Find(&keys).Error
return keys, err
}
func (s *APIKeyService) DeleteAPIKey(id uint) error {
result := database.GetDB().Delete(&models.APIKey{}, id)
if result.RowsAffected == 0 {
return errors.New("APIキーが見つかりません")
}
return result.Error
}

View File

@@ -0,0 +1,269 @@
package service
import (
"errors"
"time"
"homework-manager/internal/models"
"homework-manager/internal/repository"
)
var (
ErrAssignmentNotFound = errors.New("assignment not found")
ErrUnauthorized = errors.New("unauthorized")
)
type PaginatedResult struct {
Assignments []models.Assignment
TotalCount int64
TotalPages int
CurrentPage int
PageSize int
}
type AssignmentService struct {
assignmentRepo *repository.AssignmentRepository
}
func NewAssignmentService() *AssignmentService {
return &AssignmentService{
assignmentRepo: repository.NewAssignmentRepository(),
}
}
func (s *AssignmentService) Create(userID uint, title, description, subject, priority string, dueDate time.Time) (*models.Assignment, error) {
if priority == "" {
priority = "medium"
}
assignment := &models.Assignment{
UserID: userID,
Title: title,
Description: description,
Subject: subject,
Priority: priority,
DueDate: dueDate,
IsCompleted: false,
}
if err := s.assignmentRepo.Create(assignment); err != nil {
return nil, err
}
return assignment, nil
}
func (s *AssignmentService) GetByID(userID, assignmentID uint) (*models.Assignment, error) {
assignment, err := s.assignmentRepo.FindByID(assignmentID)
if err != nil {
return nil, ErrAssignmentNotFound
}
if assignment.UserID != userID {
return nil, ErrUnauthorized
}
return assignment, nil
}
func (s *AssignmentService) GetAllByUser(userID uint) ([]models.Assignment, error) {
return s.assignmentRepo.FindByUserID(userID)
}
func (s *AssignmentService) GetPendingByUser(userID uint) ([]models.Assignment, error) {
return s.assignmentRepo.FindPendingByUserID(userID, 0, 0)
}
func (s *AssignmentService) GetPendingByUserPaginated(userID uint, page, pageSize int) (*PaginatedResult, error) {
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
offset := (page - 1) * pageSize
assignments, err := s.assignmentRepo.FindPendingByUserID(userID, pageSize, offset)
if err != nil {
return nil, err
}
totalCount, _ := s.assignmentRepo.CountPendingByUserID(userID)
totalPages := int((totalCount + int64(pageSize) - 1) / int64(pageSize))
return &PaginatedResult{
Assignments: assignments,
TotalCount: totalCount,
TotalPages: totalPages,
CurrentPage: page,
PageSize: pageSize,
}, nil
}
func (s *AssignmentService) GetCompletedByUser(userID uint) ([]models.Assignment, error) {
return s.assignmentRepo.FindCompletedByUserID(userID, 0, 0)
}
func (s *AssignmentService) GetCompletedByUserPaginated(userID uint, page, pageSize int) (*PaginatedResult, error) {
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
offset := (page - 1) * pageSize
assignments, err := s.assignmentRepo.FindCompletedByUserID(userID, pageSize, offset)
if err != nil {
return nil, err
}
totalCount, _ := s.assignmentRepo.CountCompletedByUserID(userID)
totalPages := int((totalCount + int64(pageSize) - 1) / int64(pageSize))
return &PaginatedResult{
Assignments: assignments,
TotalCount: totalCount,
TotalPages: totalPages,
CurrentPage: page,
PageSize: pageSize,
}, nil
}
func (s *AssignmentService) GetDueTodayByUser(userID uint) ([]models.Assignment, error) {
return s.assignmentRepo.FindDueTodayByUserID(userID)
}
func (s *AssignmentService) GetDueThisWeekByUser(userID uint) ([]models.Assignment, error) {
return s.assignmentRepo.FindDueThisWeekByUserID(userID)
}
func (s *AssignmentService) GetOverdueByUser(userID uint) ([]models.Assignment, error) {
return s.assignmentRepo.FindOverdueByUserID(userID, 0, 0)
}
func (s *AssignmentService) GetOverdueByUserPaginated(userID uint, page, pageSize int) (*PaginatedResult, error) {
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
offset := (page - 1) * pageSize
assignments, err := s.assignmentRepo.FindOverdueByUserID(userID, pageSize, offset)
if err != nil {
return nil, err
}
totalCount, _ := s.assignmentRepo.CountOverdueByUserID(userID)
totalPages := int((totalCount + int64(pageSize) - 1) / int64(pageSize))
return &PaginatedResult{
Assignments: assignments,
TotalCount: totalCount,
TotalPages: totalPages,
CurrentPage: page,
PageSize: pageSize,
}, nil
}
func (s *AssignmentService) SearchAssignments(userID uint, query, priority, filter string, page, pageSize int) (*PaginatedResult, error) {
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
assignments, totalCount, err := s.assignmentRepo.Search(userID, query, priority, filter, page, pageSize)
if err != nil {
return nil, err
}
totalPages := int((totalCount + int64(pageSize) - 1) / int64(pageSize))
return &PaginatedResult{
Assignments: assignments,
TotalCount: totalCount,
TotalPages: totalPages,
CurrentPage: page,
PageSize: pageSize,
}, nil
}
func (s *AssignmentService) Update(userID, assignmentID uint, title, description, subject, priority string, dueDate time.Time) (*models.Assignment, error) {
assignment, err := s.GetByID(userID, assignmentID)
if err != nil {
return nil, err
}
assignment.Title = title
assignment.Description = description
assignment.Subject = subject
assignment.Priority = priority
assignment.DueDate = dueDate
if err := s.assignmentRepo.Update(assignment); err != nil {
return nil, err
}
return assignment, nil
}
func (s *AssignmentService) ToggleComplete(userID, assignmentID uint) (*models.Assignment, error) {
assignment, err := s.GetByID(userID, assignmentID)
if err != nil {
return nil, err
}
assignment.IsCompleted = !assignment.IsCompleted
if assignment.IsCompleted {
now := time.Now()
assignment.CompletedAt = &now
} else {
assignment.CompletedAt = nil
}
if err := s.assignmentRepo.Update(assignment); err != nil {
return nil, err
}
return assignment, nil
}
func (s *AssignmentService) Delete(userID, assignmentID uint) error {
assignment, err := s.GetByID(userID, assignmentID)
if err != nil {
return err
}
return s.assignmentRepo.Delete(assignment.ID)
}
func (s *AssignmentService) GetSubjectsByUser(userID uint) ([]string, error) {
return s.assignmentRepo.GetSubjectsByUserID(userID)
}
type DashboardStats struct {
TotalPending int64
DueToday int
DueThisWeek int
Overdue int
Subjects []string
}
func (s *AssignmentService) GetDashboardStats(userID uint) (*DashboardStats, error) {
pending, _ := s.assignmentRepo.CountPendingByUserID(userID)
dueToday, _ := s.assignmentRepo.FindDueTodayByUserID(userID)
dueThisWeek, _ := s.assignmentRepo.FindDueThisWeekByUserID(userID)
overdueCount, _ := s.assignmentRepo.CountOverdueByUserID(userID)
subjects, _ := s.assignmentRepo.GetSubjectsByUserID(userID)
return &DashboardStats{
TotalPending: pending,
DueToday: len(dueToday),
DueThisWeek: len(dueThisWeek),
Overdue: int(overdueCount),
Subjects: subjects,
}, nil
}

View File

@@ -0,0 +1,106 @@
package service
import (
"errors"
"homework-manager/internal/models"
"homework-manager/internal/repository"
"golang.org/x/crypto/bcrypt"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrEmailAlreadyExists = errors.New("email already exists")
ErrInvalidCredentials = errors.New("invalid credentials")
)
type AuthService struct {
userRepo *repository.UserRepository
}
func NewAuthService() *AuthService {
return &AuthService{
userRepo: repository.NewUserRepository(),
}
}
func (s *AuthService) Register(email, password, name string) (*models.User, error) {
// Check if email already exists
_, err := s.userRepo.FindByEmail(email)
if err == nil {
return nil, ErrEmailAlreadyExists
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, err
}
// Determine role (first user is admin)
role := "user"
count, _ := s.userRepo.Count()
if count == 0 {
role = "admin"
}
user := &models.User{
Email: email,
PasswordHash: string(hashedPassword),
Name: name,
Role: role,
}
if err := s.userRepo.Create(user); err != nil {
return nil, err
}
return user, nil
}
func (s *AuthService) Login(email, password string) (*models.User, error) {
user, err := s.userRepo.FindByEmail(email)
if err != nil {
return nil, ErrInvalidCredentials
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return nil, ErrInvalidCredentials
}
return user, nil
}
func (s *AuthService) GetUserByID(id uint) (*models.User, error) {
return s.userRepo.FindByID(id)
}
func (s *AuthService) ChangePassword(userID uint, oldPassword, newPassword string) error {
user, err := s.userRepo.FindByID(userID)
if err != nil {
return ErrUserNotFound
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword)); err != nil {
return ErrInvalidCredentials
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return err
}
user.PasswordHash = string(hashedPassword)
return s.userRepo.Update(user)
}
func (s *AuthService) UpdateProfile(userID uint, name string) error {
user, err := s.userRepo.FindByID(userID)
if err != nil {
return ErrUserNotFound
}
user.Name = name
return s.userRepo.Update(user)
}