import { eq } from 'drizzle-orm';
import { db } from '../db.js';
import { recommendations } from '../db/schema.js';
import { runInterpreter } from '../agents/interpreter.js';
import { runRetrieval } from '../agents/retrieval.js';
import { runValidator } from '../agents/validator.js';
import { runRanking } from '../agents/ranking.js';
import { runCurator } from '../agents/curator.js';
import type { CuratorOutput, InterpreterOutput, MediaType, RankingOutput, RetrievalCandidate, SSEEvent } from '../types/agents.js';
import { generateTitle } from '../agents/titleGenerator.js';

/* -- Agent pipeline --
   [1]   Interpreter -> takes the raw user input and transforms it into structured data
   [2]   Retrieval   -> brainstorms candidate titles via OpenAI (high temperature)
   [2.5] Validator (optional) -> verifies candidates exist, removes trash entries
   [3]   Ranking     -> ranks candidates against the user's preferences
   [4]   Curator     -> curates the ranked candidates into the final output
*/

// Row shape of the `recommendations` table, inferred from the drizzle schema.
type RecommendationRecord = typeof recommendations.$inferSelect;

function getBucketCount(count: number): number {
|
||
return Math.ceil(count / 15);
|
||
}
|
||
|
||
function deduplicateCandidates(candidates: RetrievalCandidate[], seenTitles?: Set<string>): RetrievalCandidate[] {
|
||
const seen = seenTitles ?? new Set<string>();
|
||
return candidates.filter((c) => {
|
||
const key = c.title.toLowerCase();
|
||
if (seen.has(key)) return false;
|
||
seen.add(key);
|
||
return true;
|
||
});
|
||
}
|
||
|
||
function splitIntoBuckets<T>(items: T[], n: number): T[][] {
|
||
const size = Math.ceil(items.length / n);
|
||
return Array.from({ length: n }, (_, i) => items.slice(i * size, (i + 1) * size))
|
||
.filter((b) => b.length > 0);
|
||
}
|
||
|
||
function mergeCuratorOutputs(a: CuratorOutput[], b: CuratorOutput[]): CuratorOutput[] {
|
||
const seen = new Set(a.map((x) => x.title.toLowerCase()));
|
||
return [...a, ...b.filter((x) => !seen.has(x.title.toLowerCase()))];
|
||
}
|
||
|
||
function log(recId: string, msg: string, data?: unknown) {
|
||
const ts = new Date().toISOString();
|
||
if (data !== undefined) {
|
||
console.log(`[pipeline] [${ts}] [${recId}] ${msg}`, data);
|
||
} else {
|
||
console.log(`[pipeline] [${ts}] [${recId}] ${msg}`);
|
||
}
|
||
}
|
||
|
||
/**
 * Inputs for one retrieval → [validator] → ranking → curator pass.
 * A pass is either the initial run (stagePrefix '') or a self-expansive
 * follow-up pass (stagePrefix 'passN:').
 */
interface SubPipelineCtx {
  // Recommendation record id; used as the correlation key in log lines.
  recId: string;
  // Structured preferences produced by the interpreter agent.
  interpreterOutput: InterpreterOutput;
  // e.g. 'tv_show' — forwarded to every agent.
  mediaType: MediaType;
  // Forwarded to the retrieval and curator agents.
  useWebSearch: boolean;
  // When true, runs the validator stage to drop entries flagged as trash.
  useValidator: boolean;
  // Forwarded to the retrieval and ranking agents.
  useHardRequirements: boolean;
  // Total candidate count to request; split across parallel retrieval buckets.
  brainstormCount: number;
  // Full-match titles from earlier passes, given to retrieval as context.
  previousFullMatches: string[];
  // Mutable set of lower-cased titles seen across all passes (cross-pass dedup).
  allSeenTitles: Set<string>;
  // '' for pass 1, 'passN:' for expansive passes; prepended to SSE stage names.
  stagePrefix: string;
  // Sink for server-sent progress events.
  sseWrite: (event: SSEEvent) => void;
}

/**
 * Runs one full sub-pipeline pass: bucketed retrieval, optional validation,
 * bucketed ranking, and bucketed curation. Emits start/done SSE events for
 * each stage (stage names prefixed with `ctx.stagePrefix`) and logs timings.
 *
 * Buckets within a stage run in parallel via Promise.all; the stages
 * themselves run strictly in sequence.
 *
 * @returns The curated items for this pass (deduplicated by title).
 * @throws Propagates any agent failure to the caller (handled in runPipeline).
 */
async function runSubPipeline(ctx: SubPipelineCtx): Promise<CuratorOutput[]> {
  const {
    recId, interpreterOutput, mediaType, useWebSearch, useValidator,
    useHardRequirements, brainstormCount, previousFullMatches,
    allSeenTitles, stagePrefix, sseWrite,
  } = ctx;

  // Builds a prefixed SSE stage name ('retrieval' vs 'pass2:retrieval', ...).
  const p = (stage: string) => (stagePrefix + stage) as SSEEvent['stage'];

  // --- Retrieval (bucketed) ---
  // Split the brainstorm quota across parallel retrieval calls, then
  // flatten and deduplicate (allSeenTitles also filters repeats from
  // earlier passes, since the set is shared and mutated).
  log(recId, `${stagePrefix}Retrieval: start`);
  sseWrite({ stage: p('retrieval'), status: 'start' });
  const t1 = Date.now();
  const retrievalBucketCount = getBucketCount(brainstormCount);
  const perBucketCount = Math.ceil(brainstormCount / retrievalBucketCount);
  const retrievalBuckets = await Promise.all(
    Array.from({ length: retrievalBucketCount }, () =>
      runRetrieval(interpreterOutput, perBucketCount, mediaType, useWebSearch, useHardRequirements, previousFullMatches)
    )
  );
  const allCandidates = retrievalBuckets.flatMap((r) => r.candidates);
  const dedupedCandidates = deduplicateCandidates(allCandidates, allSeenTitles);
  log(recId, `${stagePrefix}Retrieval: done (${Date.now() - t1}ms) — ${dedupedCandidates.length} candidates (${retrievalBucketCount} buckets, ${allCandidates.length} before dedup)`, {
    titles: dedupedCandidates.map((c) => c.title),
  });
  sseWrite({ stage: p('retrieval'), status: 'done', data: { candidates: dedupedCandidates } });

  // --- Validator (optional) ---
  // Removes candidates the validator flags as trash; otherwise the stage is
  // reported as skipped so the client still sees a 'done' event for it.
  let candidatesForRanking = dedupedCandidates;
  if (useValidator) {
    log(recId, `${stagePrefix}Validator: start`);
    sseWrite({ stage: p('validator'), status: 'start' });
    const tV = Date.now();
    const validatorOutput = await runValidator(dedupedCandidates, mediaType);
    const verified = validatorOutput.candidates.filter((c) => !c.isTrash);
    const trashCount = validatorOutput.candidates.length - verified.length;
    // Strip validator-only fields back down to the retrieval candidate shape.
    candidatesForRanking = verified.map(({ title, reason }) => ({ title, reason }));
    log(recId, `${stagePrefix}Validator: done (${Date.now() - tV}ms) — removed ${trashCount} trash entries`);
    sseWrite({ stage: p('validator'), status: 'done', data: { removed: trashCount } });
  } else {
    sseWrite({ stage: p('validator'), status: 'done', data: { skipped: true } });
  }

  // --- Ranking (bucketed) ---
  // Each bucket is ranked independently, then the per-category title lists
  // are concatenated and deduplicated case-insensitively (keeping the first
  // spelling encountered).
  log(recId, `${stagePrefix}Ranking: start`);
  sseWrite({ stage: p('ranking'), status: 'start' });
  const t2 = Date.now();
  const rankBucketCount = getBucketCount(candidatesForRanking.length);
  const candidateBuckets = splitIntoBuckets(candidatesForRanking, rankBucketCount);
  const rankingBuckets = await Promise.all(
    candidateBuckets.map((bucket) =>
      runRanking(interpreterOutput, { candidates: bucket }, mediaType, useHardRequirements)
    )
  );
  const dedupTitles = (titles: string[]) => [...new Map(titles.map((t) => [t.toLowerCase(), t])).values()];
  const rankingOutput: RankingOutput = {
    full_match: dedupTitles(rankingBuckets.flatMap((r) => r.full_match)),
    definitely_like: dedupTitles(rankingBuckets.flatMap((r) => r.definitely_like)),
    might_like: dedupTitles(rankingBuckets.flatMap((r) => r.might_like)),
    questionable: dedupTitles(rankingBuckets.flatMap((r) => r.questionable)),
    will_not_like: dedupTitles(rankingBuckets.flatMap((r) => r.will_not_like)),
  };
  log(recId, `${stagePrefix}Ranking: done (${Date.now() - t2}ms) — ${rankBucketCount} buckets`, {
    full_match: rankingOutput.full_match.length,
    definitely_like: rankingOutput.definitely_like.length,
    might_like: rankingOutput.might_like.length,
    questionable: rankingOutput.questionable.length,
    will_not_like: rankingOutput.will_not_like.length,
  });
  sseWrite({ stage: p('ranking'), status: 'done', data: rankingOutput });

  // --- Curator (bucketed) ---
  // Flatten the ranking into (title, category) pairs so bucketing can cut
  // across categories evenly, then rebuild a per-bucket RankingOutput for
  // each curator call. Bucket results are merged with title-level dedup.
  log(recId, `${stagePrefix}Curator: start`);
  sseWrite({ stage: p('curator'), status: 'start' });
  const t3 = Date.now();
  type CategorizedItem = { title: string; category: keyof RankingOutput };
  const categorizedItems: CategorizedItem[] = [
    ...rankingOutput.full_match.map((t) => ({ title: t, category: 'full_match' as const })),
    ...rankingOutput.definitely_like.map((t) => ({ title: t, category: 'definitely_like' as const })),
    ...rankingOutput.might_like.map((t) => ({ title: t, category: 'might_like' as const })),
    ...rankingOutput.questionable.map((t) => ({ title: t, category: 'questionable' as const })),
    ...rankingOutput.will_not_like.map((t) => ({ title: t, category: 'will_not_like' as const })),
  ];
  const curatorBucketCount = getBucketCount(categorizedItems.length);
  const curatorItemBuckets = splitIntoBuckets(categorizedItems, curatorBucketCount);
  const curatorBucketRankings: RankingOutput[] = curatorItemBuckets.map((bucket) => ({
    full_match: bucket.filter((i) => i.category === 'full_match').map((i) => i.title),
    definitely_like: bucket.filter((i) => i.category === 'definitely_like').map((i) => i.title),
    might_like: bucket.filter((i) => i.category === 'might_like').map((i) => i.title),
    questionable: bucket.filter((i) => i.category === 'questionable').map((i) => i.title),
    will_not_like: bucket.filter((i) => i.category === 'will_not_like').map((i) => i.title),
  }));
  const curatorBucketOutputs = await Promise.all(
    curatorBucketRankings.map((ranking) =>
      runCurator(ranking, interpreterOutput, mediaType, useWebSearch)
    )
  );
  const curatorOutput = curatorBucketOutputs.reduce((acc, bucket) => mergeCuratorOutputs(acc, bucket), [] as CuratorOutput[]);
  log(recId, `${stagePrefix}Curator: done (${Date.now() - t3}ms) — ${curatorOutput.length} items curated (${curatorBucketCount} buckets)`);
  sseWrite({ stage: p('curator'), status: 'done', data: curatorOutput });

  return curatorOutput;
}

/**
 * Entry point for the full recommendation pipeline.
 *
 * Flow: mark record running → interpreter → pass 1 sub-pipeline →
 * optional self-expansive extra passes → AI title generation → sort &
 * persist results → emit 'complete'. On any failure, emits an 'error'
 * SSE event for the stage that was active, marks the record 'error',
 * and returns [] (errors are not rethrown).
 *
 * @param rec             The recommendation row driving the run (flags,
 *                        prompts, counts). Assumes brainstorm_count /
 *                        expansive_passes / expansive_mode are populated
 *                        when their feature flags are set — TODO confirm
 *                        against the schema defaults.
 * @param sseWrite        Sink for server-sent progress events.
 * @param feedbackContext Optional user feedback forwarded to the interpreter.
 * @returns Curated recommendations sorted by category, or [] on error.
 */
export async function runPipeline(
  rec: RecommendationRecord,
  sseWrite: (event: SSEEvent) => void,
  feedbackContext?: string,
): Promise<CuratorOutput[]> {
  // Tracked so the catch block can report which stage failed.
  let currentStage: SSEEvent['stage'] = 'interpreter';
  const startTime = Date.now();
  const mediaType = (rec.media_type ?? 'tv_show') as MediaType;
  const useWebSearch = rec.use_web_search ?? false;
  const useValidator = rec.use_validator ?? false;
  const useHardRequirements = rec.hard_requirements ?? false;
  const selfExpansive = rec.self_expansive ?? false;

  log(rec.id, `Starting pipeline for "${rec.title}" [${mediaType}${useWebSearch ? ', web_search' : ''}${useValidator ? ', validator' : ''}${useHardRequirements ? ', hard_req' : ''}${selfExpansive ? `, expansive×${rec.expansive_passes}(${rec.expansive_mode})` : ''}]${feedbackContext ? ' (with feedback context)' : ''}`);

  try {
    // Set status to running
    log(rec.id, 'Setting status → running');
    await db
      .update(recommendations)
      .set({ status: 'running' })
      .where(eq(recommendations.id, rec.id));

    // --- Interpreter ---
    // Turns the raw prompt/likes/dislikes into the structured preferences
    // every downstream agent consumes.
    currentStage = 'interpreter';
    log(rec.id, 'Interpreter: start');
    sseWrite({ stage: 'interpreter', status: 'start' });
    const t0 = Date.now();
    const interpreterOutput = await runInterpreter({
      main_prompt: rec.main_prompt,
      liked_series: rec.liked_series,
      disliked_series: rec.disliked_series,
      themes: rec.themes,
      media_type: mediaType,
      // Only include the key when feedback was actually provided.
      ...(feedbackContext !== undefined ? { feedback_context: feedbackContext } : {}),
    });
    log(rec.id, `Interpreter: done (${Date.now() - t0}ms)`, {
      liked: interpreterOutput.liked,
      disliked: interpreterOutput.disliked,
      themes: interpreterOutput.themes,
      tone: interpreterOutput.tone,
      avoid: interpreterOutput.avoid,
    });
    sseWrite({ stage: 'interpreter', status: 'done', data: interpreterOutput });

    // --- Pass 1: Retrieval → [Validator?] → Ranking → Curator ---
    // allSeenTitles is shared with later passes so they never re-emit a
    // title already produced here.
    currentStage = 'retrieval';
    const allSeenTitles = new Set<string>();
    const pass1Output = await runSubPipeline({
      recId: rec.id,
      interpreterOutput,
      mediaType,
      useWebSearch,
      useValidator,
      useHardRequirements,
      brainstormCount: rec.brainstorm_count,
      previousFullMatches: [],
      allSeenTitles,
      stagePrefix: '',
      // Wrap the sink so currentStage always reflects the latest stage.
      sseWrite: (event) => {
        currentStage = event.stage;
        sseWrite(event);
      },
    });

    let mergedOutput = pass1Output;

    // --- Self Expansive: extra passes ---
    // Each extra pass feeds the accumulated 'Full Match' titles back into
    // retrieval as context; 'extreme' mode reuses the full brainstorm
    // count per pass, otherwise each pass requests 60 candidates.
    if (selfExpansive && rec.expansive_passes > 0) {
      const allFullMatches = pass1Output
        .filter((c) => c.category === 'Full Match')
        .map((c) => c.title);

      for (let i = 0; i < rec.expansive_passes; i++) {
        const passNum = i + 2; // pass 1 was the initial sub-pipeline
        const passCount = rec.expansive_mode === 'extreme' ? rec.brainstorm_count : 60;
        const passPrefix = `pass${passNum}:` as const;

        log(rec.id, `Self Expansive Pass ${passNum}: start (${passCount} candidates, ${allFullMatches.length} full matches as context)`);
        currentStage = `${passPrefix}retrieval` as SSEEvent['stage'];

        const passOutput = await runSubPipeline({
          recId: rec.id,
          interpreterOutput,
          mediaType,
          useWebSearch,
          useValidator,
          useHardRequirements,
          brainstormCount: passCount,
          previousFullMatches: [...allFullMatches],
          allSeenTitles,
          stagePrefix: passPrefix,
          sseWrite: (event) => {
            currentStage = event.stage;
            sseWrite(event);
          },
        });

        mergedOutput = mergeCuratorOutputs(mergedOutput, passOutput);

        // Grow the full-match context for the next pass.
        const newFullMatches = passOutput
          .filter((c) => c.category === 'Full Match')
          .map((c) => c.title);
        allFullMatches.push(...newFullMatches);

        log(rec.id, `Self Expansive Pass ${passNum}: done — ${passOutput.length} new items, ${mergedOutput.length} total`);
      }
    }

    // Generate AI title
    // Best-effort: a title-generation failure keeps the user's own title
    // and does not fail the pipeline.
    let aiTitle: string = rec.title;
    try {
      log(rec.id, 'Title generation: start');
      aiTitle = await generateTitle(interpreterOutput, mediaType);
      log(rec.id, `Title generation: done — "${aiTitle}"`);
    } catch (err) {
      log(rec.id, `Title generation failed, keeping initial title: ${String(err)}`);
    }

    // Sort by category order before saving
    // Unknown categories (?? 99) sink to the end instead of crashing.
    const CATEGORY_ORDER: Record<string, number> = {
      'Full Match': 0,
      'Definitely Like': 1,
      'Might Like': 2,
      'Questionable': 3,
      'Will Not Like': 4,
    };
    mergedOutput.sort((a, b) => (CATEGORY_ORDER[a.category] ?? 99) - (CATEGORY_ORDER[b.category] ?? 99));

    // Save results to DB
    log(rec.id, 'Saving results to DB');
    await db
      .update(recommendations)
      .set({ recommendations: mergedOutput, status: 'done', title: aiTitle })
      .where(eq(recommendations.id, rec.id));

    sseWrite({ stage: 'complete', status: 'done', data: { title: aiTitle } });

    log(rec.id, `Pipeline complete (total: ${Date.now() - startTime}ms)`);
    return mergedOutput;
  } catch (err) {
    const message = err instanceof Error ? err.message : String(err);
    log(rec.id, `Pipeline error at stage "${currentStage}": ${message}`);

    // Tell the client which stage blew up, then persist the error state.
    sseWrite({ stage: currentStage, status: 'error', data: { message } });

    await db
      .update(recommendations)
      .set({ status: 'error' })
      .where(eq(recommendations.id, rec.id));

    return [];
  }
}