Files
memory/server/internal/handler/auth.go
2025-12-14 20:33:33 +08:00

148 lines
3.9 KiB
Go

package handler
import (
"database/sql"
"net/http"
"time"
"memory/internal/config"
"memory/internal/middleware"
"memory/internal/model"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
)
// AuthHandler serves the registration and login endpoints. It holds the
// database handle used for the users/settings queries and the application
// configuration, which supplies the JWT signing secret.
type AuthHandler struct {
	db  *sql.DB        // storage for users and settings rows
	cfg *config.Config // application config; JWTSecret is used to sign tokens
}

// NewAuthHandler returns an AuthHandler bound to the given database handle
// and configuration. Neither argument is validated here.
func NewAuthHandler(db *sql.DB, cfg *config.Config) *AuthHandler {
	h := new(AuthHandler)
	h.db = db
	h.cfg = cfg
	return h
}
// Register creates a user account from a JSON-encoded model.RegisterRequest
// and responds 201 with a freshly signed token plus the new user.
//
// Checks, in order:
//   - the body must bind to model.RegisterRequest (400 otherwise);
//   - when the settings row 'allow_register' holds "false", registration is
//     refused with 403 unless the users table is still empty, so the very
//     first account can always be created;
//   - the username must not already exist (409);
//   - the first user ever created is flagged as admin.
//
// Database errors during these checks abort the request with 500 instead of
// being read as "zero rows". Treating a failed COUNT(*) as 0 could otherwise
// bypass the registration switch, skip the duplicate-username check, or grant
// admin rights to a user who is not actually the first.
//
// NOTE(review): the user count and the INSERT are separate statements, so two
// concurrent first registrations can both observe an empty table and both be
// made admin. Closing that gap needs a transaction (or a single
// INSERT ... SELECT) suited to the underlying database.
func (h *AuthHandler) Register(c *gin.Context) {
	var req model.RegisterRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Count existing users once; the result drives both the registration
	// switch and the first-user-becomes-admin rule below.
	var userCount int
	if err := h.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "database error"})
		return
	}

	// Check whether registration is allowed. A missing setting row means
	// "allowed"; any other lookup failure fails closed.
	var allowRegister string
	err := h.db.QueryRow("SELECT value FROM settings WHERE key = 'allow_register'").Scan(&allowRegister)
	if err != nil && err != sql.ErrNoRows {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "database error"})
		return
	}
	if err == nil && allowRegister == "false" && userCount > 0 {
		c.JSON(http.StatusForbidden, gin.H{"error": "registration is disabled"})
		return
	}

	// Reject duplicate usernames.
	var exists int
	if err := h.db.QueryRow("SELECT COUNT(*) FROM users WHERE username = ?", req.Username).Scan(&exists); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "database error"})
		return
	}
	if exists > 0 {
		c.JSON(http.StatusConflict, gin.H{"error": "username already exists"})
		return
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to hash password"})
		return
	}

	// The very first account becomes the administrator.
	isAdmin := userCount == 0

	result, err := h.db.Exec(
		"INSERT INTO users (username, password_hash, nickname, is_admin) VALUES (?, ?, ?, ?)",
		req.Username, string(hashedPassword), req.Nickname, isAdmin,
	)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create user"})
		return
	}
	// Without a valid ID the token would be issued for user 0, so treat a
	// failure here as a server error rather than ignoring it.
	userID, err := result.LastInsertId()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "database error"})
		return
	}

	token, err := h.generateToken(userID, isAdmin)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate token"})
		return
	}

	c.JSON(http.StatusCreated, model.LoginResponse{
		Token: token,
		User: &model.User{
			ID:       userID,
			Username: req.Username,
			Nickname: req.Nickname,
			IsAdmin:  isAdmin,
		},
	})
}
// Login authenticates the username/password in a JSON-encoded
// model.LoginRequest. On success it responds 200 with a signed token and the
// stored user record.
//
// An unknown username and a wrong password both produce the same 401
// "invalid credentials" body. The unknown-user path returns before any bcrypt
// comparison, so response timing still differs. Other database failures
// produce 500.
//
// NOTE(review): the scanned user, including PasswordHash, is placed in the
// response as-is; confirm model.User excludes that field from JSON output.
func (h *AuthHandler) Login(c *gin.Context) {
	var req model.LoginRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	const query = "SELECT id, username, password_hash, nickname, avatar_url, bio, is_admin, created_at FROM users WHERE username = ?"
	var u model.User
	row := h.db.QueryRow(query, req.Username)
	switch scanErr := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Nickname, &u.AvatarURL, &u.Bio, &u.IsAdmin, &u.CreatedAt); {
	case scanErr == sql.ErrNoRows:
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid credentials"})
		return
	case scanErr != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "database error"})
		return
	}

	if bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(req.Password)) != nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid credentials"})
		return
	}

	signed, tokenErr := h.generateToken(u.ID, u.IsAdmin)
	if tokenErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate token"})
		return
	}

	resp := model.LoginResponse{Token: signed, User: &u}
	c.JSON(http.StatusOK, resp)
}
// generateToken signs an HS256 JWT that carries the user's ID and admin flag.
// The token is valid for seven days from issuance and is keyed with
// h.cfg.JWTSecret.
//
// It refuses to sign when the secret is empty: an HMAC keyed with nothing is
// trivially forgeable, so a misconfigured deployment fails loudly with
// jwt.ErrInvalidKey instead of issuing insecure tokens.
func (h *AuthHandler) generateToken(userID int64, isAdmin bool) (string, error) {
	if len(h.cfg.JWTSecret) == 0 {
		return "", jwt.ErrInvalidKey
	}

	const tokenTTL = 7 * 24 * time.Hour
	// Take a single timestamp so exp is exactly tokenTTL after iat.
	now := time.Now()
	claims := &middleware.Claims{
		UserID:  userID,
		IsAdmin: isAdmin,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(tokenTTL)),
			IssuedAt:  jwt.NewNumericDate(now),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString([]byte(h.cfg.JWTSecret))
}