208 lines
8.6 KiB
TypeScript
208 lines
8.6 KiB
TypeScript
import { eq } from 'drizzle-orm';
|
|
import { db } from '../db.js';
|
|
import { recommendations } from '../db/schema.js';
|
|
import { runInterpreter } from '../agents/interpreter.js';
|
|
import { runRetrieval } from '../agents/retrieval.js';
|
|
import { runRanking } from '../agents/ranking.js';
|
|
import { runCurator } from '../agents/curator.js';
|
|
import type { CuratorOutput, MediaType, RankingOutput, RetrievalCandidate, SSEEvent } from '../types/agents.js';
|
|
import { generateTitle } from '../agents/titleGenerator.js';
|
|
|
|
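/* -- Agent pipeline --

[1] Interpreter -> turns the user's request (main prompt, liked/disliked titles, themes, optional feedback) into structured preferences

[2] Retrieval -> brainstorms candidate titles from OpenAI (high temperature), in parallel buckets whose results are merged and de-duplicated

[3] Ranking -> sorts the candidates into match categories (full_match ... will_not_like) against the interpreted preferences, in parallel buckets

[4] Curator -> turns the ranked categories into the final curated recommendations, in parallel buckets

Afterwards an AI title is generated (falling back to the existing title) and the results are saved to the DB.

*/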
|
|
|
|
/** A row of the `recommendations` table, as returned by a Drizzle select query. */
type RecommendationRecord = (typeof recommendations)['$inferSelect'];
|
|
|
|
/**
 * Chooses how many parallel buckets a stage should split `count` items across:
 * 1 bucket for up to 50 items, 2 for up to 100, 3 for up to 150, and 4 otherwise.
 *
 * @param count - Number of items (or requested items) the stage will handle.
 * @returns A bucket count between 1 and 4.
 */
function getBucketCount(count: number): number {
  const upperBounds = [50, 100, 150];
  // Index of the first bound that `count` fits under; -1 means it exceeds all of them.
  const firstFitting = upperBounds.findIndex((limit) => count <= limit);
  return firstFitting === -1 ? upperBounds.length + 1 : firstFitting + 1;
}
|
|
|
|
/**
 * Removes candidates whose title repeats an earlier candidate's title, keeping
 * the first occurrence and preserving input order.
 *
 * Titles are compared case-insensitively after trimming and collapsing runs of
 * whitespace, so parallel retrieval buckets that return e.g. "The Wire" and
 * " the  Wire " are treated as the same title. The returned candidates keep
 * their original, un-normalized titles.
 *
 * @param candidates - Candidates merged from all retrieval buckets.
 * @returns A new array without duplicates; the input array is not mutated.
 */
function deduplicateCandidates(candidates: RetrievalCandidate[]): RetrievalCandidate[] {
  // Hoisted so it is not re-created per candidate; `replace` resets lastIndex for /g regexes.
  const whitespaceRun = /\s+/g;
  const seen = new Set<string>();
  return candidates.filter((candidate) => {
    const key = candidate.title.trim().replace(whitespaceRun, ' ').toLowerCase();
    if (seen.has(key)) return false;
    seen.add(key);
    return true;
  });
}
|
|
|
|
/**
 * Splits `items` into at most `n` contiguous, order-preserving buckets of
 * near-equal size: every bucket except the last holds `ceil(length / n)` items.
 * Empty buckets are never returned, so fewer than `n` buckets can come back.
 *
 * `n` is normalised before use. It is floored to an integer and clamped to
 * `[1, items.length]`, and NaN is treated as 1. Without this, a fractional `n`
 * (e.g. 2.5) would silently drop trailing items, `n <= 0` would drop all of
 * them, and `Infinity` would throw.
 *
 * @param items - Items to split; not mutated.
 * @param n - Desired number of buckets.
 * @returns The non-empty buckets, in input order; `[]` for empty input.
 */
function splitIntoBuckets<T>(items: readonly T[], n: number): T[][] {
  if (items.length === 0) return [];
  const bucketCount = Number.isNaN(n) ? 1 : Math.min(items.length, Math.max(1, Math.floor(n)));
  const size = Math.ceil(items.length / bucketCount);
  const buckets: T[][] = [];
  for (let start = 0; start < items.length; start += size) {
    buckets.push(items.slice(start, start + size));
  }
  return buckets;
}
|
|
|
|
/**
 * Writes a timestamped pipeline log line tagged with the recommendation id.
 *
 * When `data` is anything other than `undefined`, it is handed to
 * `console.log` as a separate second argument instead of being interpolated
 * into the message.
 */
function log(recId: string, msg: string, data?: unknown) {
  const line = `[pipeline] [${new Date().toISOString()}] [${recId}] ${msg}`;
  const extra: unknown[] = data === undefined ? [] : [data];
  console.log(line, ...extra);
}
|
|
|
|
/**
 * Concatenates per-bucket ranking results category by category, keeping
 * bucket order within each category.
 */
function mergeRankings(rankings: RankingOutput[]): RankingOutput {
  return {
    full_match: rankings.flatMap((r) => r.full_match),
    definitely_like: rankings.flatMap((r) => r.definitely_like),
    might_like: rankings.flatMap((r) => r.might_like),
    questionable: rankings.flatMap((r) => r.questionable),
    will_not_like: rankings.flatMap((r) => r.will_not_like),
  };
}

/**
 * Re-splits a merged ranking into curator-sized buckets.
 *
 * Titles are laid out in category order (full_match first, will_not_like last),
 * cut into contiguous buckets sized by `getBucketCount`, and each bucket is
 * turned back into a `RankingOutput` holding only its own titles. Returns an
 * empty array when the ranking contains no titles.
 */
function splitRankingIntoBuckets(ranking: RankingOutput): RankingOutput[] {
  type CategorizedItem = { title: string; category: keyof RankingOutput };
  const items: CategorizedItem[] = [
    ...ranking.full_match.map((title) => ({ title, category: 'full_match' as const })),
    ...ranking.definitely_like.map((title) => ({ title, category: 'definitely_like' as const })),
    ...ranking.might_like.map((title) => ({ title, category: 'might_like' as const })),
    ...ranking.questionable.map((title) => ({ title, category: 'questionable' as const })),
    ...ranking.will_not_like.map((title) => ({ title, category: 'will_not_like' as const })),
  ];
  return splitIntoBuckets(items, getBucketCount(items.length)).map((bucket) => {
    const titlesIn = (category: keyof RankingOutput): string[] =>
      bucket.filter((item) => item.category === category).map((item) => item.title);
    return {
      full_match: titlesIn('full_match'),
      definitely_like: titlesIn('definitely_like'),
      might_like: titlesIn('might_like'),
      questionable: titlesIn('questionable'),
      will_not_like: titlesIn('will_not_like'),
    };
  });
}

/**
 * Runs the recommendation pipeline for one record: interpreter → retrieval →
 * ranking → curator, then title generation and saving the results to the DB.
 *
 * Progress is streamed through `sseWrite`: a `start` and a `done` event per
 * stage, then a final `complete` event. On success the record is updated with
 * the curated recommendations, status `done`, and the AI-generated title. The
 * record keeps its existing title if title generation throws or returns a blank
 * string.
 *
 * On failure the function does not reject. It logs the error, sends an `error`
 * event for the stage that was running, marks the record `error`, and returns
 * `[]`. Reporting the failure is best-effort: if the SSE write or the status
 * update itself throws, that is logged and `[]` is still returned.
 *
 * @param rec - The recommendation row to process.
 * @param sseWrite - Callback used to stream stage events to the client.
 * @param feedbackContext - Optional extra context forwarded to the interpreter.
 * @returns The curated recommendations, or `[]` if the pipeline failed.
 */
export async function runPipeline(
  rec: RecommendationRecord,
  sseWrite: (event: SSEEvent) => void,
  feedbackContext?: string,
): Promise<CuratorOutput[]> {
  // Tracks the running stage so a failure can be attributed to it.
  let currentStage: SSEEvent['stage'] = 'interpreter';
  const startTime = Date.now();
  const mediaType = (rec.media_type ?? 'tv_show') as MediaType;
  const useWebSearch = rec.use_web_search ?? false;

  log(rec.id, `Starting pipeline for "${rec.title}" [${mediaType}${useWebSearch ? ', web_search' : ''}]${feedbackContext ? ' (with feedback context)' : ''}`);

  try {
    // Set status to running
    log(rec.id, 'Setting status → running');
    await db
      .update(recommendations)
      .set({ status: 'running' })
      .where(eq(recommendations.id, rec.id));

    // --- Interpreter ---
    currentStage = 'interpreter';
    log(rec.id, 'Interpreter: start');
    sseWrite({ stage: 'interpreter', status: 'start' });
    const t0 = Date.now();
    const interpreterOutput = await runInterpreter({
      main_prompt: rec.main_prompt,
      liked_shows: rec.liked_shows,
      disliked_shows: rec.disliked_shows,
      themes: rec.themes,
      media_type: mediaType,
      // The feedback_context key is only present when feedback was supplied.
      ...(feedbackContext !== undefined ? { feedback_context: feedbackContext } : {}),
    });
    log(rec.id, `Interpreter: done (${Date.now() - t0}ms)`, {
      liked: interpreterOutput.liked,
      disliked: interpreterOutput.disliked,
      themes: interpreterOutput.themes,
      tone: interpreterOutput.tone,
      avoid: interpreterOutput.avoid,
    });
    sseWrite({ stage: 'interpreter', status: 'done', data: interpreterOutput });

    // --- Retrieval (bucketed) ---
    // Independent retrieval calls run in parallel, each asked for an equal share
    // (rounded up) of brainstorm_count; results are merged and de-duplicated.
    currentStage = 'retrieval';
    log(rec.id, 'Retrieval: start');
    sseWrite({ stage: 'retrieval', status: 'start' });
    const t1 = Date.now();
    const retrievalBucketCount = getBucketCount(rec.brainstorm_count);
    const perBucketCount = Math.ceil(rec.brainstorm_count / retrievalBucketCount);
    const retrievalBuckets = await Promise.all(
      Array.from({ length: retrievalBucketCount }, () =>
        runRetrieval(interpreterOutput, perBucketCount, mediaType, useWebSearch),
      ),
    );
    const allCandidates = retrievalBuckets.flatMap((r) => r.candidates);
    const dedupedCandidates = deduplicateCandidates(allCandidates);
    const retrievalOutput = { candidates: dedupedCandidates };
    log(rec.id, `Retrieval: done (${Date.now() - t1}ms) — ${dedupedCandidates.length} candidates (${retrievalBucketCount} buckets, ${allCandidates.length} before dedup)`, {
      titles: dedupedCandidates.map((c) => c.title),
    });
    sseWrite({ stage: 'retrieval', status: 'done', data: retrievalOutput });

    // --- Ranking (bucketed) ---
    currentStage = 'ranking';
    log(rec.id, 'Ranking: start');
    sseWrite({ stage: 'ranking', status: 'start' });
    const t2 = Date.now();
    const rankBucketCount = getBucketCount(dedupedCandidates.length);
    const candidateBuckets = splitIntoBuckets(dedupedCandidates, rankBucketCount);
    const rankingBuckets = await Promise.all(
      candidateBuckets.map((bucket) => runRanking(interpreterOutput, { candidates: bucket }, mediaType)),
    );
    const rankingOutput = mergeRankings(rankingBuckets);
    log(rec.id, `Ranking: done (${Date.now() - t2}ms) — ${rankBucketCount} buckets`, {
      full_match: rankingOutput.full_match.length,
      definitely_like: rankingOutput.definitely_like.length,
      might_like: rankingOutput.might_like.length,
      questionable: rankingOutput.questionable.length,
      will_not_like: rankingOutput.will_not_like.length,
    });
    sseWrite({ stage: 'ranking', status: 'done', data: rankingOutput });

    // --- Curator (bucketed) ---
    currentStage = 'curator';
    log(rec.id, 'Curator: start');
    sseWrite({ stage: 'curator', status: 'start' });
    const t3 = Date.now();
    const curatorBucketRankings = splitRankingIntoBuckets(rankingOutput);
    const curatorBucketOutputs = await Promise.all(
      curatorBucketRankings.map((ranking) => runCurator(ranking, interpreterOutput, mediaType, useWebSearch)),
    );
    const curatorOutput = curatorBucketOutputs.flat();
    log(rec.id, `Curator: done (${Date.now() - t3}ms) — ${curatorOutput.length} items curated (${curatorBucketRankings.length} buckets)`);
    sseWrite({ stage: 'curator', status: 'done', data: curatorOutput });

    // Generate AI title. This is best-effort: on failure or a blank result the
    // record keeps its existing title.
    let aiTitle: string = rec.title;
    try {
      log(rec.id, 'Title generation: start');
      const generatedTitle = (await generateTitle(interpreterOutput, mediaType)).trim();
      if (generatedTitle.length > 0) {
        aiTitle = generatedTitle;
        log(rec.id, `Title generation: done — "${aiTitle}"`);
      } else {
        log(rec.id, 'Title generation returned an empty title, keeping initial title');
      }
    } catch (err) {
      log(rec.id, `Title generation failed, keeping initial title: ${String(err)}`);
    }

    // Save results to DB
    log(rec.id, 'Saving results to DB');
    await db
      .update(recommendations)
      .set({ recommendations: curatorOutput, status: 'done', title: aiTitle })
      .where(eq(recommendations.id, rec.id));

    sseWrite({ stage: 'complete', status: 'done' });

    log(rec.id, `Pipeline complete (total: ${Date.now() - startTime}ms)`);
    return curatorOutput;
  } catch (err) {
    const message = err instanceof Error ? err.message : String(err);
    log(rec.id, `Pipeline error at stage "${currentStage}": ${message}`, err instanceof Error ? err.stack : undefined);

    // Each reporting step is guarded on its own. If sseWrite throws here (for
    // example because the same failure that got us into this catch came from
    // sseWrite), the status update must still run; otherwise the record would
    // stay `running` forever.
    try {
      sseWrite({ stage: currentStage, status: 'error', data: { message } });
    } catch (sseErr) {
      log(rec.id, `Failed to send error event: ${String(sseErr)}`);
    }

    try {
      await db
        .update(recommendations)
        .set({ status: 'error' })
        .where(eq(recommendations.id, rec.id));
    } catch (dbErr) {
      log(rec.id, `Failed to set status → error: ${String(dbErr)}`);
    }

    return [];
  }
}
|