Files
memory/server/internal/handler/upload.go
2025-12-14 20:33:33 +08:00

172 lines
3.9 KiB
Go

package handler
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"memory/internal/config"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// UploadHandler serves file uploads. Files are stored on local disk when
// cfg.LocalUploadPath is set, otherwise in Cloudflare R2 through an
// S3-compatible client.
type UploadHandler struct {
// cfg supplies storage settings (local path, base URLs, R2 credentials and bucket).
cfg *config.Config
// s3Client is nil when R2 is not configured or the client could not be built.
s3Client *s3.Client
}
// NewUploadHandler builds an UploadHandler from cfg.
//
// An R2 (S3-compatible) client is created only when the R2 account ID, access
// key ID and access key secret are all configured. Otherwise s3Client stays nil
// and Upload uses local storage or reports that no storage is configured.
//
// If the AWS configuration cannot be loaded, the error is written to stderr
// instead of being silently discarded. The handler is still returned, just
// without an R2 client.
func NewUploadHandler(cfg *config.Config) *UploadHandler {
	h := &UploadHandler{cfg: cfg}
	if cfg.R2AccountID == "" || cfg.R2AccessKeyID == "" || cfg.R2AccessKeySecret == "" {
		// Incomplete credentials would only fail later at request-signing time.
		return h
	}

	awsCfg, err := awsconfig.LoadDefaultConfig(context.Background(),
		awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
			cfg.R2AccessKeyID,
			cfg.R2AccessKeySecret,
			"",
		)),
		// Cloudflare R2 uses the region name "auto".
		awsconfig.WithRegion("auto"),
	)
	if err != nil {
		fmt.Fprintf(os.Stderr, "upload: R2 storage disabled: loading AWS config: %v\n", err)
		return h
	}

	endpoint := fmt.Sprintf("https://%s.r2.cloudflarestorage.com", cfg.R2AccountID)
	h.s3Client = s3.NewFromConfig(awsCfg, func(o *s3.Options) {
		// A per-client base endpoint replaces the deprecated global
		// EndpointResolverWithOptions.
		o.BaseEndpoint = aws.String(endpoint)
	})
	return h
}
// Upload accepts a multipart form upload in the "file" field and stores it.
//
// Only .jpg, .jpeg, .png, .gif, .webp, .mp4 and .mov files up to 50 MiB
// (50*1024*1024 bytes) are accepted. The type is decided by the file
// extension, not the content. The stored key is "YYYY/MM/<uuid><ext>".
//
// Storage is local disk when LocalUploadPath is set; otherwise R2 is used
// when a client is available.
//
// Responses:
//   - 200 with {"url", "filename"} on success.
//   - 400 for client errors.
//   - 500 for read/storage failures; the underlying error is attached with
//     c.Error so logging middleware can report it.
func (h *UploadHandler) Upload(c *gin.Context) {
	const (
		maxFileSize = 50 * 1024 * 1024 // 50MB
		// Headroom for multipart boundaries, part headers and small form fields.
		multipartOverhead = 1024 * 1024
		maxBodySize       = maxFileSize + multipartOverhead
	)

	// Bound the body before gin parses the multipart form. Parsing reads the
	// whole body (spilling large parts to temporary files), so the file.Size
	// check below would otherwise only run after an arbitrarily large body had
	// already been accepted.
	if c.Request.ContentLength > maxBodySize {
		c.JSON(http.StatusBadRequest, gin.H{"error": "file too large (max 50MB)"})
		return
	}
	c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBodySize)

	file, err := c.FormFile("file")
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "no file uploaded"})
		return
	}

	// Check the file type by extension.
	ext := strings.ToLower(filepath.Ext(file.Filename))
	allowedExts := map[string]string{
		".jpg":  "image/jpeg",
		".jpeg": "image/jpeg",
		".png":  "image/png",
		".gif":  "image/gif",
		".webp": "image/webp",
		".mp4":  "video/mp4",
		".mov":  "video/quicktime",
	}
	contentType, ok := allowedExts[ext]
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported file type"})
		return
	}

	if file.Size > maxFileSize {
		c.JSON(http.StatusBadRequest, gin.H{"error": "file too large (max 50MB)"})
		return
	}

	// Unique key grouped by month: "YYYY/MM/<uuid><ext>".
	filename := fmt.Sprintf("%s/%s%s",
		time.Now().Format("2006/01"),
		uuid.New().String(),
		ext,
	)

	src, err := file.Open()
	if err != nil {
		_ = c.Error(err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read file"})
		return
	}
	defer src.Close()

	var url string
	switch {
	case h.cfg.LocalUploadPath != "":
		// Local storage takes precedence over R2.
		url, err = h.saveLocal(src, filename)
		if err != nil {
			_ = c.Error(err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save file"})
			return
		}
	case h.s3Client != nil:
		url, err = h.uploadToR2(src, filename, contentType)
		if err != nil {
			_ = c.Error(err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to upload file"})
			return
		}
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "storage not configured"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"url":      url,
		"filename": filename,
	})
}
// saveLocal writes src to <LocalUploadPath>/<filename> and returns
// "<BaseURL>/uploads/<filename>". Parent directories are created as needed.
//
// filename must be a server-generated relative key (Upload builds it from the
// date and a UUID). It is not sanitized here.
//
// The file is created exclusively, so an existing file is never overwritten.
// If writing or closing fails, the partially written file is removed and the
// error is returned.
func (h *UploadHandler) saveLocal(src io.Reader, filename string) (string, error) {
	fullPath := filepath.Join(h.cfg.LocalUploadPath, filename)
	if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
		return "", fmt.Errorf("creating upload directory: %w", err)
	}

	// O_EXCL: refuse to clobber an existing file, even on a key collision.
	dst, err := os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
	if err != nil {
		return "", fmt.Errorf("creating upload file: %w", err)
	}

	_, err = io.Copy(dst, src)
	// Check the Close error: on some filesystems (e.g. NFS) write failures
	// surface only when the file is closed.
	if cerr := dst.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		// Best effort: the write error is what the caller needs to see.
		_ = os.Remove(fullPath)
		return "", fmt.Errorf("writing upload file: %w", err)
	}

	url := strings.TrimSuffix(h.cfg.BaseURL, "/") + "/uploads/" + filename
	return url, nil
}
// uploadToR2 stores src in the configured R2 bucket under key filename with
// the given Content-Type and returns "<R2PublicURL>/<filename>".
//
// It must only be called when h.s3Client is non-nil.
// NOTE(review): the SDK generally needs a seekable body (or a known length)
// to sign the payload. Upload passes a multipart.File, which is seekable;
// confirm this before reusing the helper with other readers.
func (h *UploadHandler) uploadToR2(src io.Reader, filename, contentType string) (string, error) {
	input := &s3.PutObjectInput{
		Bucket:      aws.String(h.cfg.R2BucketName),
		Key:         aws.String(filename),
		Body:        src,
		ContentType: aws.String(contentType),
	}
	// NOTE(review): context.TODO means a client disconnect does not cancel the
	// upload; consider threading the request context through.
	if _, err := h.s3Client.PutObject(context.TODO(), input); err != nil {
		return "", fmt.Errorf("putting object %q in bucket %q: %w", filename, h.cfg.R2BucketName, err)
	}
	return strings.TrimSuffix(h.cfg.R2PublicURL, "/") + "/" + filename, nil
}