Files
memory/server/internal/handler/upload.go

221 lines
5.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"memory/internal/config"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// UploadHandler handles file uploads and serves proxied downloads.
// Files are stored in Cloudflare R2 when s3Client is non-nil; otherwise
// Upload falls back to the local directory cfg.LocalUploadPath, if set.
type UploadHandler struct {
// cfg holds the application configuration (R2 account, credentials and
// bucket, local upload path, base and public URLs).
cfg *config.Config
// s3Client talks to R2. It is nil when R2 credentials are not configured
// or when loading the SDK configuration failed in NewUploadHandler.
s3Client *s3.Client
}
// NewUploadHandler builds an UploadHandler from cfg.
//
// When both R2AccountID and R2AccessKeyID are set, it also creates an S3
// client pointed at https://<R2AccountID>.r2.cloudflarestorage.com, using
// static credentials (R2AccessKeyID / R2AccessKeySecret) and the "auto"
// region. If the SDK configuration cannot be loaded, the error is reported
// on gin.DefaultErrorWriter and the handler is returned without an S3
// client. Upload then falls back to local storage, if configured.
func NewUploadHandler(cfg *config.Config) *UploadHandler {
	h := &UploadHandler{cfg: cfg}
	if cfg.R2AccountID == "" || cfg.R2AccessKeyID == "" {
		return h
	}

	// Initialise the R2 client: every service/region resolves to the
	// account's R2 endpoint.
	// NOTE(review): EndpointResolverWithOptions is deprecated in newer
	// aws-sdk-go-v2 releases in favour of s3.Options.BaseEndpoint; consider
	// migrating once the SDK version in go.mod is confirmed to support it.
	r2Endpoint := fmt.Sprintf("https://%s.r2.cloudflarestorage.com", cfg.R2AccountID)
	r2Resolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
		return aws.Endpoint{URL: r2Endpoint}, nil
	})

	awsCfg, err := awsconfig.LoadDefaultConfig(context.Background(),
		awsconfig.WithEndpointResolverWithOptions(r2Resolver),
		awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
			cfg.R2AccessKeyID,
			cfg.R2AccessKeySecret,
			"",
		)),
		awsconfig.WithRegion("auto"),
	)
	if err != nil {
		// Surface the failure: otherwise a broken R2 setup is
		// indistinguishable from "R2 not configured" and uploads silently
		// go to local disk (or fail with "storage not configured").
		fmt.Fprintf(gin.DefaultErrorWriter, "[upload] R2 disabled: loading AWS config: %v\n", err)
		return h
	}
	h.s3Client = s3.NewFromConfig(awsCfg)
	return h
}
// Limits enforced by UploadHandler.Upload.
const (
	// maxImageUploadSize is the largest accepted image file (50 MiB).
	maxImageUploadSize int64 = 50 << 20
	// maxVideoUploadSize is the largest accepted video file (100 MiB).
	maxVideoUploadSize int64 = 100 << 20
	// maxUploadRequestSize caps the raw multipart request body: the largest
	// accepted file plus 1 MiB of headroom for boundaries and part headers.
	maxUploadRequestSize = maxVideoUploadSize + 1<<20
)

// uploadContentTypes maps an accepted, lower-cased file extension to the
// Content-Type the stored object is given. Any other extension is rejected.
var uploadContentTypes = map[string]string{
	".jpg":  "image/jpeg",
	".jpeg": "image/jpeg",
	".png":  "image/png",
	".gif":  "image/gif",
	".webp": "image/webp",
	".mp4":  "video/mp4",
	".mov":  "video/quicktime",
}

// Upload accepts a multipart form with a single "file" field, validates its
// extension (see uploadContentTypes) and size (images up to 50 MiB, videos
// up to 100 MiB), and stores it under a server-generated key of the form
// "YYYY/MM/<uuid><ext>". R2 is used when h.s3Client is set; otherwise the
// file goes to cfg.LocalUploadPath.
//
// Responses:
//   - 200 {"url": ..., "filename": <key>} on success
//   - 400 for a missing file, unsupported type, or oversized file
//   - 413 when the declared request size exceeds maxUploadRequestSize
//   - 500 when storage fails or no storage is configured
func (h *UploadHandler) Upload(c *gin.Context) {
	// Bound the body before the multipart form is parsed. Without this the
	// entire body, however large, is read (and spooled to a temp file)
	// before the per-file size check below ever runs. The Content-Length
	// check rejects honest clients cheaply; MaxBytesReader stops chunked or
	// lying clients mid-stream, which then fail FormFile below.
	if c.Request.ContentLength > maxUploadRequestSize {
		c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "request too large (max 100MB)"})
		return
	}
	c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxUploadRequestSize)

	file, err := c.FormFile("file")
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "no file uploaded"})
		return
	}

	// Check the file type. The stored Content-Type is derived from the
	// extension alone; the file content is not sniffed.
	ext := strings.ToLower(filepath.Ext(file.Filename))
	contentType, ok := uploadContentTypes[ext]
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported file type"})
		return
	}

	// Check the file size (images max 50MB, videos max 100MB).
	maxSize, tooLargeMsg := maxImageUploadSize, "file too large (max 50MB)"
	if strings.HasPrefix(contentType, "video/") {
		maxSize, tooLargeMsg = maxVideoUploadSize, "video too large (max 100MB)"
	}
	if file.Size > maxSize {
		c.JSON(http.StatusBadRequest, gin.H{"error": tooLargeMsg})
		return
	}

	// Generate a unique key. The client-supplied name contributes only its
	// already-validated extension, so it cannot influence the storage path.
	filename := time.Now().Format("2006/01") + "/" + uuid.New().String() + ext

	src, err := file.Open()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read file"})
		return
	}
	defer src.Close()

	// Prefer R2, fall back to local storage. Storage errors are recorded on
	// the gin context (for logging middleware) rather than echoed to the
	// client, so internal details are not leaked.
	var url string
	switch {
	case h.s3Client != nil:
		url, err = h.uploadToR2(src, filename, contentType)
		if err != nil {
			_ = c.Error(fmt.Errorf("uploading %q to R2: %w", filename, err)) // returned *gin.Error not needed
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to upload file"})
			return
		}
	case h.cfg.LocalUploadPath != "":
		url, err = h.saveLocal(src, filename)
		if err != nil {
			_ = c.Error(fmt.Errorf("saving %q locally: %w", filename, err)) // returned *gin.Error not needed
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save file"})
			return
		}
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "storage not configured"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"url":      url,
		"filename": filename,
	})
}
// saveLocal writes the contents of src to <cfg.LocalUploadPath>/<filename>,
// creating parent directories (mode 0755) as needed. It returns the URL
// <cfg.BaseURL>/uploads/<filename>.
//
// filename is expected to be the server-generated "YYYY/MM/<uuid><ext>" key
// built by Upload, so it is joined into the path without further
// sanitisation. If writing or closing the file fails, the partially written
// file is removed and a wrapped error is returned.
func (h *UploadHandler) saveLocal(src io.Reader, filename string) (string, error) {
	fullPath := filepath.Join(h.cfg.LocalUploadPath, filename)
	if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
		return "", fmt.Errorf("creating upload directory: %w", err)
	}

	dst, err := os.Create(fullPath)
	if err != nil {
		return "", fmt.Errorf("creating %q: %w", fullPath, err)
	}
	if _, err := io.Copy(dst, src); err != nil {
		// Best-effort cleanup; the copy error is the one worth reporting.
		_ = dst.Close()
		_ = os.Remove(fullPath)
		return "", fmt.Errorf("writing %q: %w", fullPath, err)
	}
	// Close explicitly rather than via defer: for a file being written, a
	// Close error can mean the data never reached disk.
	if err := dst.Close(); err != nil {
		_ = os.Remove(fullPath) // best-effort cleanup of the incomplete file
		return "", fmt.Errorf("closing %q: %w", fullPath, err)
	}

	return fmt.Sprintf("%s/uploads/%s", strings.TrimSuffix(h.cfg.BaseURL, "/"), filename), nil
}
// uploadToR2 stores src in the configured R2 bucket under the key filename
// with the given Content-Type, then returns the URL clients should use to
// fetch it:
//   - <R2PublicURL>/<filename> when a public bucket URL is configured;
//   - otherwise <BaseURL>/files/<filename>, i.e. the server's proxy route.
//
// The PutObject error, if any, is returned unchanged.
func (h *UploadHandler) uploadToR2(src io.Reader, filename, contentType string) (string, error) {
	input := &s3.PutObjectInput{
		Bucket:      aws.String(h.cfg.R2BucketName),
		Key:         aws.String(filename),
		Body:        src,
		ContentType: aws.String(contentType),
	}
	if _, err := h.s3Client.PutObject(context.TODO(), input); err != nil {
		return "", err
	}

	base, route := h.cfg.BaseURL, "/files/"
	if h.cfg.R2PublicURL != "" {
		base, route = h.cfg.R2PublicURL, "/"
	}
	return strings.TrimSuffix(base, "/") + route + filename, nil
}
// GetFile proxies an R2 object to the client, for buckets without a public
// URL (uploadToR2 hands out <BaseURL>/files/<key> in that case).
//
// The object key comes from the "filepath" route parameter, e.g.
// "2024/01/<uuid>.jpg". A leading "/" is stripped: gin includes it in the
// value of a catch-all parameter (as in "/files/*filepath"), while keys are
// stored without one. The response carries the stored Content-Type and a
// one-year public Cache-Control header; keys are never reused because
// every upload gets a fresh UUID.
//
// Responses: 400 if the key is empty, 404 if R2 is not configured or the
// fetch fails, otherwise 200 with the object body streamed through.
func (h *UploadHandler) GetFile(c *gin.Context) {
	// Named key (not "filepath") so the path/filepath import is not shadowed.
	key := strings.TrimPrefix(c.Param("filepath"), "/")
	if key == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "filepath required"})
		return
	}
	if h.s3Client == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
		return
	}

	// Use the request context so the R2 fetch is cancelled if the client
	// disconnects.
	result, err := h.s3Client.GetObject(c.Request.Context(), &s3.GetObjectInput{
		Bucket: aws.String(h.cfg.R2BucketName),
		Key:    aws.String(key),
	})
	if err != nil {
		// NOTE(review): every GetObject error (including auth/network
		// failures) is reported as 404; distinguishing NoSuchKey would need
		// the s3/types package.
		c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
		return
	}
	defer result.Body.Close()

	if result.ContentType != nil {
		c.Header("Content-Type", *result.ContentType)
	}
	c.Header("Cache-Control", "public, max-age=31536000")

	// Stream the body. Headers are already committed once copying starts, so
	// a failure can only be recorded, not turned into an error status.
	c.Status(http.StatusOK)
	if _, err := io.Copy(c.Writer, result.Body); err != nil {
		_ = c.Error(fmt.Errorf("streaming %q from R2: %w", key, err)) // returned *gin.Error not needed
	}
}