221 lines
5.4 KiB
Go
221 lines
5.4 KiB
Go
package handler
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
"memory/internal/config"
|
||
|
||
"github.com/aws/aws-sdk-go-v2/aws"
|
||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
type UploadHandler struct {
|
||
cfg *config.Config
|
||
s3Client *s3.Client
|
||
}
|
||
|
||
func NewUploadHandler(cfg *config.Config) *UploadHandler {
|
||
h := &UploadHandler{cfg: cfg}
|
||
|
||
if cfg.R2AccountID != "" && cfg.R2AccessKeyID != "" {
|
||
// 初始化 R2 客户端
|
||
r2Resolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
|
||
return aws.Endpoint{
|
||
URL: fmt.Sprintf("https://%s.r2.cloudflarestorage.com", cfg.R2AccountID),
|
||
}, nil
|
||
})
|
||
|
||
awsCfg, err := awsconfig.LoadDefaultConfig(context.TODO(),
|
||
awsconfig.WithEndpointResolverWithOptions(r2Resolver),
|
||
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
||
cfg.R2AccessKeyID,
|
||
cfg.R2AccessKeySecret,
|
||
"",
|
||
)),
|
||
awsconfig.WithRegion("auto"),
|
||
)
|
||
if err == nil {
|
||
h.s3Client = s3.NewFromConfig(awsCfg)
|
||
}
|
||
}
|
||
|
||
return h
|
||
}
|
||
|
||
func (h *UploadHandler) Upload(c *gin.Context) {
|
||
file, err := c.FormFile("file")
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "no file uploaded"})
|
||
return
|
||
}
|
||
|
||
// 检查文件类型
|
||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||
allowedExts := map[string]string{
|
||
".jpg": "image/jpeg",
|
||
".jpeg": "image/jpeg",
|
||
".png": "image/png",
|
||
".gif": "image/gif",
|
||
".webp": "image/webp",
|
||
".mp4": "video/mp4",
|
||
".mov": "video/quicktime",
|
||
}
|
||
|
||
contentType, ok := allowedExts[ext]
|
||
if !ok {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported file type"})
|
||
return
|
||
}
|
||
|
||
// 检查文件大小 (图片最大 50MB, 视频最大 100MB)
|
||
maxSize := int64(50 * 1024 * 1024) // 默认 50MB
|
||
if contentType == "video/mp4" || contentType == "video/quicktime" {
|
||
maxSize = 100 * 1024 * 1024 // 视频 100MB
|
||
}
|
||
if file.Size > maxSize {
|
||
if contentType == "video/mp4" || contentType == "video/quicktime" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "video too large (max 100MB)"})
|
||
} else {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "file too large (max 50MB)"})
|
||
}
|
||
return
|
||
}
|
||
|
||
// 生成唯一文件名
|
||
filename := fmt.Sprintf("%s/%s%s",
|
||
time.Now().Format("2006/01"),
|
||
uuid.New().String(),
|
||
ext,
|
||
)
|
||
|
||
// 打开文件
|
||
src, err := file.Open()
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read file"})
|
||
return
|
||
}
|
||
defer src.Close()
|
||
|
||
// 优先使用 R2 存储
|
||
if h.s3Client != nil {
|
||
url, err := h.uploadToR2(src, filename, contentType)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to upload file: " + err.Error()})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"url": url,
|
||
"filename": filename,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 使用本地存储
|
||
if h.cfg.LocalUploadPath != "" {
|
||
url, err := h.saveLocal(src, filename)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save file"})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"url": url,
|
||
"filename": filename,
|
||
})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "storage not configured"})
|
||
}
|
||
|
||
func (h *UploadHandler) saveLocal(src io.Reader, filename string) (string, error) {
|
||
// 创建目录
|
||
fullPath := filepath.Join(h.cfg.LocalUploadPath, filename)
|
||
dir := filepath.Dir(fullPath)
|
||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 创建文件
|
||
dst, err := os.Create(fullPath)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer dst.Close()
|
||
|
||
// 复制内容
|
||
if _, err := io.Copy(dst, src); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 返回 URL
|
||
url := fmt.Sprintf("%s/uploads/%s", strings.TrimSuffix(h.cfg.BaseURL, "/"), filename)
|
||
return url, nil
|
||
}
|
||
|
||
func (h *UploadHandler) uploadToR2(src io.Reader, filename, contentType string) (string, error) {
|
||
_, err := h.s3Client.PutObject(context.TODO(), &s3.PutObjectInput{
|
||
Bucket: aws.String(h.cfg.R2BucketName),
|
||
Key: aws.String(filename),
|
||
Body: src,
|
||
ContentType: aws.String(contentType),
|
||
})
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 如果配置了公开 URL,直接返回
|
||
if h.cfg.R2PublicURL != "" {
|
||
return fmt.Sprintf("%s/%s", strings.TrimSuffix(h.cfg.R2PublicURL, "/"), filename), nil
|
||
}
|
||
|
||
// 否则通过服务端代理访问
|
||
return fmt.Sprintf("%s/files/%s", strings.TrimSuffix(h.cfg.BaseURL, "/"), filename), nil
|
||
}
|
||
|
||
// GetFile 代理访问 R2 文件
|
||
func (h *UploadHandler) GetFile(c *gin.Context) {
|
||
// 获取文件路径 (格式: 2024/01/uuid.jpg)
|
||
filepath := c.Param("filepath")
|
||
if filepath == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "filepath required"})
|
||
return
|
||
}
|
||
|
||
// 从 R2 获取文件
|
||
if h.s3Client != nil {
|
||
result, err := h.s3Client.GetObject(context.TODO(), &s3.GetObjectInput{
|
||
Bucket: aws.String(h.cfg.R2BucketName),
|
||
Key: aws.String(filepath),
|
||
})
|
||
if err != nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
|
||
return
|
||
}
|
||
defer result.Body.Close()
|
||
|
||
// 设置响应头
|
||
if result.ContentType != nil {
|
||
c.Header("Content-Type", *result.ContentType)
|
||
}
|
||
c.Header("Cache-Control", "public, max-age=31536000")
|
||
|
||
// 流式传输
|
||
c.Status(http.StatusOK)
|
||
io.Copy(c.Writer, result.Body)
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
|
||
}
|