Initial AI media hub implementation
Some checks failed
build-push / docker (push) Has been cancelled
Some checks failed
build-push / docker (push) Has been cancelled
This commit is contained in:
282
backend/handlers/api.go
Normal file
282
backend/handlers/api.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ai-media-hub/backend/models"
|
||||
"ai-media-hub/backend/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
DB *sql.DB
|
||||
DownloadsDir string
|
||||
WorkerScript string
|
||||
SearchService *services.SearchService
|
||||
GeminiService *services.GeminiService
|
||||
Hub *Hub
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
clients map[*websocket.Conn]bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHub() *Hub {
|
||||
return &Hub{clients: map[*websocket.Conn]bool{}}
|
||||
}
|
||||
|
||||
func (h *Hub) Broadcast(event string, data any) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
payload, _ := json.Marshal(gin.H{"event": event, "data": data})
|
||||
for conn := range h.clients {
|
||||
_ = conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) Add(conn *websocket.Conn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.clients[conn] = true
|
||||
}
|
||||
|
||||
func (h *Hub) Remove(conn *websocket.Conn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
delete(h.clients, conn)
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
func RegisterRoutes(router *gin.Engine, app *App) {
|
||||
router.GET("/healthz", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
router.GET("/ws", app.handleWS)
|
||||
router.GET("/api/history/check", app.checkDuplicate)
|
||||
router.POST("/api/upload", app.uploadFile)
|
||||
router.POST("/api/download", app.startDownload)
|
||||
router.POST("/api/search", app.searchMedia)
|
||||
}
|
||||
|
||||
func (a *App) handleWS(c *gin.Context) {
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
a.Hub.Add(conn)
|
||||
defer a.Hub.Remove(conn)
|
||||
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) checkDuplicate(c *gin.Context) {
|
||||
url := strings.TrimSpace(c.Query("url"))
|
||||
if url == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
|
||||
return
|
||||
}
|
||||
record, err := models.FindByURL(a.DB, url)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"exists": record != nil, "record": record})
|
||||
}
|
||||
|
||||
func (a *App) uploadFile(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "file is required"})
|
||||
return
|
||||
}
|
||||
|
||||
a.Hub.Broadcast("progress", gin.H{"type": "upload", "status": "started", "progress": 5, "filename": file.Filename})
|
||||
|
||||
safeName := normalizeFilename(file.Filename)
|
||||
targetPath := filepath.Join(a.DownloadsDir, safeName)
|
||||
if err := c.SaveUploadedFile(file, targetPath); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
a.Hub.Broadcast("progress", gin.H{"type": "upload", "status": "completed", "progress": 100, "filename": safeName})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "uploaded", "path": targetPath, "filename": safeName})
|
||||
}
|
||||
|
||||
func (a *App) startDownload(c *gin.Context) {
|
||||
var req struct {
|
||||
URL string `json:"url"`
|
||||
Start string `json:"start"`
|
||||
End string `json:"end"`
|
||||
Force bool `json:"force"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rec, err := models.FindByURL(a.DB, req.URL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if rec != nil && !req.Force {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "duplicate url", "record": rec})
|
||||
return
|
||||
}
|
||||
|
||||
outputBase := uuid.NewString()
|
||||
outputPath := filepath.Join(a.DownloadsDir, outputBase+".mp4")
|
||||
recordID, err := models.InsertDownload(a.DB, req.URL, detectSource(req.URL), outputPath, "queued")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
go a.runDownload(recordID, req.URL, req.Start, req.End, outputPath)
|
||||
c.JSON(http.StatusAccepted, gin.H{"message": "download started", "recordId": recordID})
|
||||
}
|
||||
|
||||
func (a *App) runDownload(recordID int64, url, start, end, outputPath string) {
|
||||
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "queued", "progress": 0, "url": url})
|
||||
cmd := exec.Command("python3", a.WorkerScript, "--url", url, "--start", start, "--end", end, "--output", outputPath)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 0, "message": err.Error()})
|
||||
_ = models.MarkDownloadCompleted(a.DB, recordID, "failed")
|
||||
return
|
||||
}
|
||||
cmd.Stderr = cmd.Stdout
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 0, "message": err.Error()})
|
||||
_ = models.MarkDownloadCompleted(a.DB, recordID, "failed")
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(line, &msg); err == nil {
|
||||
msg["type"] = "download"
|
||||
a.Hub.Broadcast("progress", msg)
|
||||
}
|
||||
}
|
||||
|
||||
status := "completed"
|
||||
if err := cmd.Wait(); err != nil {
|
||||
status = "failed"
|
||||
a.Hub.Broadcast("progress", gin.H{"type": "download", "status": "error", "progress": 100, "message": err.Error()})
|
||||
}
|
||||
_ = models.MarkDownloadCompleted(a.DB, recordID, status)
|
||||
}
|
||||
|
||||
func (a *App) searchMedia(c *gin.Context) {
|
||||
var req struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Query) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "query is required"})
|
||||
return
|
||||
}
|
||||
|
||||
results, err := a.SearchService.SearchMedia(req.Query)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
recommended, err := a.GeminiService.Recommend(req.Query, results)
|
||||
if err != nil {
|
||||
fallback := make([]services.AIRecommendation, 0, min(4, len(results)))
|
||||
for _, result := range results[:min(4, len(results))] {
|
||||
fallback = append(fallback, services.AIRecommendation{
|
||||
Title: result.Title,
|
||||
Link: result.Link,
|
||||
ThumbnailURL: result.ThumbnailURL,
|
||||
Source: result.Source,
|
||||
Reason: "Gemini recommendation failed, showing raw search result.",
|
||||
Recommended: true,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"results": fallback, "warning": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"results": recommended})
|
||||
}
|
||||
|
||||
func normalizeFilename(name string) string {
|
||||
base := strings.ToLower(strings.TrimSpace(name))
|
||||
ext := filepath.Ext(base)
|
||||
base = strings.TrimSuffix(base, ext)
|
||||
re := regexp.MustCompile(`[^a-z0-9]+`)
|
||||
base = strings.Trim(re.ReplaceAllString(base, "-"), "-")
|
||||
if base == "" {
|
||||
base = fmt.Sprintf("upload-%d", time.Now().Unix())
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
return base + ext
|
||||
}
|
||||
|
||||
func detectSource(url string) string {
|
||||
switch {
|
||||
case strings.Contains(url, "youtube"):
|
||||
return "YouTube"
|
||||
case strings.Contains(url, "tiktok"):
|
||||
return "TikTok"
|
||||
default:
|
||||
return "direct"
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func EnsurePaths(downloadsDir, workerScript string) error {
|
||||
if err := os.MkdirAll(downloadsDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := os.Stat(workerScript); err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("worker script not found: %s", workerScript)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
66
backend/main.go
Normal file
66
backend/main.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"ai-media-hub/backend/handlers"
|
||||
"ai-media-hub/backend/models"
|
||||
"ai-media-hub/backend/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
root := envOrDefault("APP_ROOT", "/app")
|
||||
dbPath := envOrDefault("SQLITE_PATH", filepath.Join(root, "db", "media.db"))
|
||||
downloadsDir := envOrDefault("DOWNLOADS_DIR", filepath.Join(root, "downloads"))
|
||||
frontendDir := envOrDefault("FRONTEND_DIR", filepath.Join(root, "frontend"))
|
||||
workerScript := envOrDefault("WORKER_SCRIPT", filepath.Join(root, "worker", "downloader.py"))
|
||||
|
||||
db, err := models.InitDB(dbPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err := handlers.EnsurePaths(downloadsDir, workerScript); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
app := &handlers.App{
|
||||
DB: db,
|
||||
DownloadsDir: downloadsDir,
|
||||
WorkerScript: workerScript,
|
||||
SearchService: services.NewSearchService(os.Getenv("GOOGLE_CSE_API_KEY"), os.Getenv("GOOGLE_CSE_CX")),
|
||||
GeminiService: services.NewGeminiService(os.Getenv("GEMINI_API_KEY")),
|
||||
Hub: handlers.NewHub(),
|
||||
}
|
||||
|
||||
router := gin.Default()
|
||||
handlers.RegisterRoutes(router, app)
|
||||
router.StaticFile("/", filepath.Join(frontendDir, "index.html"))
|
||||
router.StaticFile("/app.js", filepath.Join(frontendDir, "app.js"))
|
||||
router.StaticFile("/style.css", filepath.Join(frontendDir, "style.css"))
|
||||
router.NoRoute(func(c *gin.Context) {
|
||||
c.File(filepath.Join(frontendDir, "index.html"))
|
||||
})
|
||||
router.NoMethod(func(c *gin.Context) {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{"error": "method not allowed"})
|
||||
})
|
||||
|
||||
addr := envOrDefault("APP_ADDR", ":8080")
|
||||
log.Printf("server listening on %s", addr)
|
||||
if err := router.Run(addr); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func envOrDefault(key, fallback string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
92
backend/models/db.go
Normal file
92
backend/models/db.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type DownloadRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
URL string `json:"url"`
|
||||
Source string `json:"source"`
|
||||
OutputPath string `json:"outputPath"`
|
||||
Status string `json:"status"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
CompletedAt time.Time `json:"completedAt,omitempty"`
|
||||
}
|
||||
|
||||
func InitDB(path string) (*sql.DB, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS download_history (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
url TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
output_path TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
started_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
completed_at DATETIME
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_download_history_url ON download_history(url);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func InsertDownload(db *sql.DB, url, source, outputPath, status string) (int64, error) {
|
||||
res, err := db.Exec(
|
||||
`INSERT INTO download_history (url, source, output_path, status) VALUES (?, ?, ?, ?)`,
|
||||
url, source, outputPath, status,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
func MarkDownloadCompleted(db *sql.DB, id int64, status string) error {
|
||||
_, err := db.Exec(
|
||||
`UPDATE download_history SET status = ?, completed_at = CURRENT_TIMESTAMP WHERE id = ?`,
|
||||
status, id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func FindByURL(db *sql.DB, url string) (*DownloadRecord, error) {
|
||||
row := db.QueryRow(
|
||||
`SELECT id, url, source, output_path, status, started_at, COALESCE(completed_at, '') FROM download_history WHERE url = ? ORDER BY id DESC LIMIT 1`,
|
||||
url,
|
||||
)
|
||||
|
||||
var rec DownloadRecord
|
||||
var completedRaw string
|
||||
if err := row.Scan(&rec.ID, &rec.URL, &rec.Source, &rec.OutputPath, &rec.Status, &rec.StartedAt, &completedRaw); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if completedRaw != "" {
|
||||
parsed, err := time.Parse("2006-01-02 15:04:05", completedRaw)
|
||||
if err == nil {
|
||||
rec.CompletedAt = parsed
|
||||
}
|
||||
}
|
||||
return &rec, nil
|
||||
}
|
||||
116
backend/services/cse.go
Normal file
116
backend/services/cse.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SearchResult struct {
|
||||
Title string `json:"title"`
|
||||
Link string `json:"link"`
|
||||
DisplayLink string `json:"displayLink"`
|
||||
Snippet string `json:"snippet"`
|
||||
ThumbnailURL string `json:"thumbnailUrl"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
type SearchService struct {
|
||||
APIKey string
|
||||
CX string
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
func NewSearchService(apiKey, cx string) *SearchService {
|
||||
return &SearchService{
|
||||
APIKey: apiKey,
|
||||
CX: cx,
|
||||
Client: &http.Client{Timeout: 20 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SearchService) SearchMedia(query string) ([]SearchResult, error) {
|
||||
if s.APIKey == "" || s.CX == "" {
|
||||
return nil, fmt.Errorf("google cse credentials are not configured")
|
||||
}
|
||||
|
||||
domains := []string{"youtube.com", "tiktok.com", "envato.com", "artgrid.io"}
|
||||
siteQuery := strings.Join(domains, " OR site:")
|
||||
fullQuery := fmt.Sprintf("%s (site:%s)", query, siteQuery)
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("key", s.APIKey)
|
||||
values.Set("cx", s.CX)
|
||||
values.Set("q", fullQuery)
|
||||
values.Set("searchType", "image")
|
||||
values.Set("num", "10")
|
||||
values.Set("safe", "off")
|
||||
|
||||
results := make([]SearchResult, 0, 30)
|
||||
seen := map[string]bool{}
|
||||
for _, start := range []string{"1", "11", "21"} {
|
||||
values.Set("start", start)
|
||||
endpoint := "https://www.googleapis.com/customsearch/v1?" + values.Encode()
|
||||
resp, err := s.Client.Get(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("google cse returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Items []struct {
|
||||
Title string `json:"title"`
|
||||
Link string `json:"link"`
|
||||
DisplayLink string `json:"displayLink"`
|
||||
Snippet string `json:"snippet"`
|
||||
Image struct {
|
||||
ThumbnailLink string `json:"thumbnailLink"`
|
||||
} `json:"image"`
|
||||
} `json:"items"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
resp.Body.Close()
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
for _, item := range payload.Items {
|
||||
if item.Link == "" || seen[item.Link] {
|
||||
continue
|
||||
}
|
||||
seen[item.Link] = true
|
||||
results = append(results, SearchResult{
|
||||
Title: item.Title,
|
||||
Link: item.Link,
|
||||
DisplayLink: item.DisplayLink,
|
||||
Snippet: item.Snippet,
|
||||
ThumbnailURL: item.Image.ThumbnailLink,
|
||||
Source: inferSource(item.DisplayLink),
|
||||
})
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func inferSource(displayLink string) string {
|
||||
switch {
|
||||
case strings.Contains(displayLink, "youtube"):
|
||||
return "YouTube"
|
||||
case strings.Contains(displayLink, "tiktok"):
|
||||
return "TikTok"
|
||||
case strings.Contains(displayLink, "envato"):
|
||||
return "Envato"
|
||||
case strings.Contains(displayLink, "artgrid"):
|
||||
return "Artgrid"
|
||||
default:
|
||||
return displayLink
|
||||
}
|
||||
}
|
||||
175
backend/services/gemini.go
Normal file
175
backend/services/gemini.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GeminiService struct {
|
||||
APIKey string
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
type AIRecommendation struct {
|
||||
Title string `json:"title"`
|
||||
Link string `json:"link"`
|
||||
ThumbnailURL string `json:"thumbnailUrl"`
|
||||
Source string `json:"source"`
|
||||
Reason string `json:"reason"`
|
||||
Recommended bool `json:"recommended"`
|
||||
}
|
||||
|
||||
func NewGeminiService(apiKey string) *GeminiService {
|
||||
return &GeminiService{
|
||||
APIKey: apiKey,
|
||||
Client: &http.Client{Timeout: 40 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GeminiService) Recommend(query string, candidates []SearchResult) ([]AIRecommendation, error) {
|
||||
if g.APIKey == "" {
|
||||
return nil, fmt.Errorf("gemini api key is not configured")
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return []AIRecommendation{}, nil
|
||||
}
|
||||
|
||||
type geminiPart map[string]any
|
||||
parts := []geminiPart{
|
||||
{
|
||||
"text": `Analyze the provided images for the user's search intent. Return JSON only in this shape:
|
||||
{"recommendations":[{"index":0,"reason":"short reason","recommended":true}]}
|
||||
Mark only the best matches as recommended=true. Keep reasons concise. User query: ` + query,
|
||||
},
|
||||
}
|
||||
|
||||
maxImages := min(len(candidates), 8)
|
||||
for idx := 0; idx < maxImages; idx++ {
|
||||
img, mimeType, err := fetchImageAsInlineData(g.Client, candidates[idx].ThumbnailURL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
parts = append(parts,
|
||||
geminiPart{"text": fmt.Sprintf("Candidate %d: title=%s source=%s link=%s", idx, candidates[idx].Title, candidates[idx].Source, candidates[idx].Link)},
|
||||
geminiPart{"inlineData": map[string]string{"mimeType": mimeType, "data": img}},
|
||||
)
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"parts": parts},
|
||||
},
|
||||
"generationConfig": map[string]any{
|
||||
"responseMimeType": "application/json",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, _ := json.Marshal(body)
|
||||
endpoint := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key=" + g.APIKey
|
||||
resp, err := g.Client.Post(endpoint, "application/json", bytes.NewReader(rawBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
data, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("gemini returned status %d: %s", resp.StatusCode, string(data))
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
} `json:"candidates"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(payload.Candidates) == 0 || len(payload.Candidates[0].Content.Parts) == 0 {
|
||||
return nil, fmt.Errorf("gemini returned no candidates")
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
Recommendations []struct {
|
||||
Index int `json:"index"`
|
||||
Reason string `json:"reason"`
|
||||
Recommended bool `json:"recommended"`
|
||||
} `json:"recommendations"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(payload.Candidates[0].Content.Parts[0].Text), &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recommendations := make([]AIRecommendation, 0, len(parsed.Recommendations))
|
||||
for _, rec := range parsed.Recommendations {
|
||||
if rec.Index < 0 || rec.Index >= len(candidates) || !rec.Recommended {
|
||||
continue
|
||||
}
|
||||
src := candidates[rec.Index]
|
||||
recommendations = append(recommendations, AIRecommendation{
|
||||
Title: src.Title,
|
||||
Link: src.Link,
|
||||
ThumbnailURL: src.ThumbnailURL,
|
||||
Source: src.Source,
|
||||
Reason: rec.Reason,
|
||||
Recommended: true,
|
||||
})
|
||||
}
|
||||
|
||||
if len(recommendations) == 0 {
|
||||
for _, candidate := range candidates[:min(4, len(candidates))] {
|
||||
recommendations = append(recommendations, AIRecommendation{
|
||||
Title: candidate.Title,
|
||||
Link: candidate.Link,
|
||||
ThumbnailURL: candidate.ThumbnailURL,
|
||||
Source: candidate.Source,
|
||||
Reason: "Fallback result because Gemini returned no recommended items.",
|
||||
Recommended: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return recommendations, nil
|
||||
}
|
||||
|
||||
func fetchImageAsInlineData(client *http.Client, imageURL string) (string, string, error) {
|
||||
resp, err := client.Get(imageURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
return "", "", fmt.Errorf("thumbnail fetch failed with %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
mimeType, _, _ := mime.ParseMediaType(contentType)
|
||||
if mimeType == "" || !strings.HasPrefix(mimeType, "image/") {
|
||||
mimeType = "image/jpeg"
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(data), mimeType, nil
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user