JimLiu-baoyu-skills/skills/baoyu-imagine/scripts/providers/replicate.ts

605 lines
17 KiB
TypeScript

import path from "node:path";
import { readFile } from "node:fs/promises";
import type { CliArgs } from "../types";
/** Model used when neither the caller nor REPLICATE_IMAGE_MODEL picks one. */
const DEFAULT_MODEL = "google/nano-banana-2";
/** Seconds requested via the `Prefer: wait=N` header on the initial create call. */
const SYNC_WAIT_SECONDS = 60;
/** Delay between status polls after the synchronous wait returns an unfinished prediction. */
const POLL_INTERVAL_MS = 2_000;
/** Total polling budget before giving up: five minutes. */
const MAX_POLL_MS = 5 * 60 * 1_000;
/**
 * Aspect ratios this tool accepts for the nano-banana and Seedream families.
 * Insertion order is preserved in the "Supported values" lists shown in error messages.
 */
const DOCUMENTED_REPLICATE_ASPECT_RATIOS = new Set([
  "1:1", "2:3", "3:2", "3:4", "4:3",
  "5:4", "4:5", "9:16", "16:9", "21:9",
]);
/** Model families with dedicated input builders; any other model id maps to "unknown". */
export type ReplicateModelFamily =
  | "nano-banana"
  | "seedream45"
  | "seedream5lite"
  | "wan27image"
  | "wan27imagepro"
  | "unknown";
/** Width/height pair: pixels from parsePixelSize, ratio terms from parseAspectRatio. */
interface PixelSize {
  width: number;
  height: number;
}
/**
 * Resolve the Replicate model to use when none is passed explicitly.
 *
 * An unset or empty `REPLICATE_IMAGE_MODEL` falls back to the built-in default.
 */
export function getDefaultModel(): string {
  const fromEnv = process.env.REPLICATE_IMAGE_MODEL;
  return fromEnv ? fromEnv : DEFAULT_MODEL;
}
/** Read the Replicate API token from the environment; `null` when unset or empty. */
function getApiToken(): string | null {
  const token = process.env.REPLICATE_API_TOKEN;
  if (!token) {
    return null;
  }
  return token;
}
/**
 * API origin, overridable through `REPLICATE_BASE_URL` (an empty value counts as unset).
 * Trailing slashes are stripped so `/v1/...` paths can be appended directly.
 */
function getBaseUrl(): string {
  const configured = process.env.REPLICATE_BASE_URL;
  const base = configured ? configured : "https://api.replicate.com";
  return base.replace(/\/+$/, "");
}
/** Canonical form used for family matching: trimmed, lower-cased, with any `:version` suffix removed. */
function normalizeModelId(model: string): string {
  const lowered = model.trim().toLowerCase();
  const colon = lowered.indexOf(":");
  return colon === -1 ? lowered : lowered.slice(0, colon);
}
/**
 * Map a model reference to the input-builder family that knows its schema.
 *
 * Matching is case-insensitive and ignores surrounding whitespace and any
 * `:version` suffix. Models outside the compatibility list map to "unknown".
 */
export function getModelFamily(model: string): ReplicateModelFamily {
  switch (normalizeModelId(model)) {
    case "google/nano-banana":
    case "google/nano-banana-pro":
    case "google/nano-banana-2":
      return "nano-banana";
    case "bytedance/seedream-4.5":
      return "seedream45";
    case "bytedance/seedream-5-lite":
      return "seedream5lite";
    case "wan-video/wan-2.7-image":
      return "wan27image";
    case "wan-video/wan-2.7-image-pro":
      return "wan27imagepro";
    default:
      return "unknown";
  }
}
/**
 * Split a Replicate model reference into its owner, name, and optional version.
 *
 * Accepts `owner/name` or `owner/name:version`. Surrounding whitespace is
 * ignored, matching the trimming the family lookup applies. A padded value,
 * for example from an environment variable, therefore cannot leak spaces into
 * the request URL. Any text after a second `:` is ignored.
 *
 * @param model - Model reference as supplied by the user or environment.
 * @returns The parsed parts; `version` is `null` when absent or empty.
 * @throws Error when the reference is not of the form `owner/name[:version]`.
 */
export function parseModelId(model: string): { owner: string; name: string; version: string | null } {
  // split() always yields at least one element; the default only satisfies the type checker.
  const [ownerName = "", version] = model.trim().split(":");
  const parts = ownerName.split("/");
  const [owner, name] = parts;
  if (parts.length !== 2 || !owner || !name) {
    throw new Error(
      `Invalid Replicate model format: "${model}". Expected "owner/name" or "owner/name:version".`
    );
  }
  return { owner, name, version: version || null };
}
/**
 * Parse an explicit pixel size such as `1536x1024`, `1536 X 1024`, or `1536*1024`.
 *
 * @returns The dimensions, or `null` when the text is malformed or either edge is not positive.
 */
function parsePixelSize(value: string): PixelSize | null {
  const groups = /^(\d+)\s*[xX*]\s*(\d+)$/.exec(value.trim());
  if (groups === null) {
    return null;
  }
  const [, rawWidth, rawHeight] = groups;
  const width = parseInt(rawWidth!, 10);
  const height = parseInt(rawHeight!, 10);
  const valid = Number.isFinite(width) && Number.isFinite(height) && width > 0 && height > 0;
  return valid ? { width, height } : null;
}
/**
 * Parse a `W:H` aspect ratio such as `16:9` (whitespace around the colon is allowed).
 *
 * @returns The ratio terms, or `null` when malformed or either term is not positive.
 */
function parseAspectRatio(value: string): PixelSize | null {
  const groups = /^(\d+)\s*:\s*(\d+)$/.exec(value.trim());
  if (groups === null) {
    return null;
  }
  const [, rawWidth, rawHeight] = groups;
  const width = parseInt(rawWidth!, 10);
  const height = parseInt(rawHeight!, 10);
  const valid = Number.isFinite(width) && Number.isFinite(height) && width > 0 && height > 0;
  return valid ? { width, height } : null;
}
/** Greatest common divisor of |a| and |b| (Euclid); a zero result is replaced by 1 so callers can divide by it. */
function gcd(a: number, b: number): number {
  let dividend = Math.abs(a);
  let divisor = Math.abs(b);
  while (divisor !== 0) {
    [dividend, divisor] = [divisor, dividend % divisor];
  }
  return dividend || 1;
}
/**
 * Reduce a `WxH` pixel size to lowest terms and return it as `W:H`.
 *
 * @returns The reduced ratio when it is one of the documented ratios; otherwise `null`
 *   (including when `size` is not a valid pixel size).
 */
function inferAspectRatioFromSize(size: string): string | null {
  const dims = parsePixelSize(size);
  if (dims === null) {
    return null;
  }
  const common = gcd(dims.width, dims.height);
  const ratio = [dims.width / common, dims.height / common].join(":");
  return DOCUMENTED_REPLICATE_ASPECT_RATIOS.has(ratio) ? ratio : null;
}
/** Collapse the CLI quality flag to the two presets this provider uses; anything but "normal" means "2k". */
function getQualityPreset(args: CliArgs): "normal" | "2k" {
  if (args.quality === "normal") {
    return "normal";
  }
  return "2k";
}
/**
 * Ensure `aspectRatio` is a documented ratio or the `match_input_image` sentinel.
 *
 * @throws Error naming the model and listing every supported ratio.
 */
function validateDocumentedAspectRatio(model: string, aspectRatio: string): void {
  const accepted =
    aspectRatio === "match_input_image" || DOCUMENTED_REPLICATE_ASPECT_RATIOS.has(aspectRatio);
  if (accepted) {
    return;
  }
  const supported = [...DOCUMENTED_REPLICATE_ASPECT_RATIOS].join(", ");
  throw new Error(
    `Replicate model ${model} does not support aspect ratio ${aspectRatio}. Supported values: ${supported}`
  );
}
/**
 * Determine the aspect ratio the user asked for, either explicitly via `--ar` or implied by `--size`.
 *
 * An explicit ratio wins and is validated. Otherwise a `--size` is reduced to a
 * documented ratio. With neither option the result is `null`.
 *
 * @throws Error when the explicit ratio is unsupported or `--size` does not reduce to a documented ratio.
 */
function getRequestedAspectRatio(model: string, args: CliArgs): string | null {
  const explicit = args.aspectRatio;
  if (explicit) {
    validateDocumentedAspectRatio(model, explicit);
    return explicit;
  }
  if (!args.size) {
    return null;
  }
  const inferred = inferAspectRatioFromSize(args.size);
  if (inferred) {
    return inferred;
  }
  throw new Error(
    `Replicate model ${model} cannot derive a supported aspect ratio from --size ${args.size}. Use one of: ${Array.from(DOCUMENTED_REPLICATE_ASPECT_RATIOS).join(", ")}`
  );
}
/**
 * Pick the nano-banana `resolution` input.
 *
 * With `--size`, the longest edge decides: up to 1024px is "1K" and up to 2048px is "2K".
 * Without it, the quality preset decides ("normal" → "1K", otherwise "2K").
 *
 * @throws Error when `--size` is not `WxH` or its longest edge exceeds 2048px.
 */
function getNanoBananaResolution(args: CliArgs): "1K" | "2K" {
  if (!args.size) {
    return getQualityPreset(args) === "normal" ? "1K" : "2K";
  }
  const parsed = parsePixelSize(args.size);
  if (!parsed) {
    throw new Error("Replicate nano-banana --size must be in WxH format, for example 1536x1024.");
  }
  const longestEdge = Math.max(parsed.width, parsed.height);
  if (longestEdge > 2048) {
    throw new Error("Replicate nano-banana only supports sizes that map to 1K or 2K output.");
  }
  return longestEdge <= 1024 ? "1K" : "2K";
}
/**
 * Pick the Seedream 4.5 `size` input.
 *
 * Accepts a "2K"/"4K" preset (case-insensitive) or an explicit `WxH` with both
 * edges in [1024, 4096]; the explicit form is re-emitted as `WxH`. Without
 * `--size`, the quality preset decides ("normal" → "2K", otherwise "4K").
 *
 * NOTE(review): confirm Replicate's Seedream 4.5 schema accepts a literal "WxH"
 * string for `size`; some Seedream versions expect size="custom" plus width/height.
 *
 * @throws Error when `--size` is neither a preset nor a valid in-range `WxH`.
 */
function resolveSeedream45Size(args: CliArgs): "2K" | "4K" | string {
  if (!args.size) {
    return getQualityPreset(args) === "normal" ? "2K" : "4K";
  }
  const preset = args.size.trim().toUpperCase();
  if (preset === "2K" || preset === "4K") {
    return preset;
  }
  const parsed = parsePixelSize(args.size);
  if (!parsed) {
    throw new Error("Replicate Seedream 4.5 --size must be 2K, 4K, or an explicit WxH size.");
  }
  const inRange = (edge: number): boolean => edge >= 1024 && edge <= 4096;
  if (!inRange(parsed.width) || !inRange(parsed.height)) {
    throw new Error("Replicate Seedream 4.5 custom --size must keep width and height between 1024 and 4096.");
  }
  return `${parsed.width}x${parsed.height}`;
}
/**
 * Pick the Seedream 5 Lite `size` input: a "2K"/"3K" preset (case-insensitive),
 * or the quality preset when `--size` is absent ("normal" → "2K", otherwise "3K").
 *
 * @throws Error for any other `--size` value.
 */
function resolveSeedream5LiteSize(args: CliArgs): "2K" | "3K" {
  if (!args.size) {
    return getQualityPreset(args) === "normal" ? "2K" : "3K";
  }
  const preset = args.size.trim().toUpperCase();
  if (preset !== "2K" && preset !== "3K") {
    throw new Error("Replicate Seedream 5 Lite currently supports 2K or 3K output in this tool.");
  }
  return preset;
}
/** Render a custom size in the `W*H` form sent to the Wan models. */
function formatCustomWanSize(size: PixelSize): string {
  return [size.width, size.height].join("*");
}
/**
 * Turn a `W:H` ratio into a Wan `W*H` size whose longer edge is `maxDimension`.
 *
 * Edges are floored and never drop below 1px.
 *
 * @throws Error when `aspectRatio` is not in `W:H` form.
 */
function resolveWanSizeFromAspectRatio(
  aspectRatio: string,
  maxDimension: number,
): string {
  const ratio = parseAspectRatio(aspectRatio);
  if (ratio === null) {
    throw new Error(`Replicate Wan aspect ratio must be in W:H format, got ${aspectRatio}.`);
  }
  // Dividing by the larger term is the same as taking the smaller of the two per-axis scales.
  const scale = maxDimension / Math.max(ratio.width, ratio.height);
  return formatCustomWanSize({
    width: Math.max(1, Math.floor(ratio.width * scale)),
    height: Math.max(1, Math.floor(ratio.height * scale)),
  });
}
/**
 * Pick the Wan 2.7 `size` input.
 *
 * Resolution order:
 * 1. `--size` preset "1K"/"2K"/"4K". "4K" is allowed only for Image Pro without reference images.
 * 2. `--size` as explicit `WxH`, capped at 4096px for Pro text-to-image and 2048px otherwise.
 * 3. `--ar`, scaled so the longer edge is 1024px ("normal" quality) or 2048px.
 * 4. The quality preset ("normal" → "1K", otherwise "2K").
 *
 * @throws Error for unsupported presets, malformed sizes, or sizes above the mode's cap.
 */
function resolveWanSize(family: "wan27image" | "wan27imagepro", args: CliArgs): "1K" | "2K" | "4K" | string {
  const isPro = family === "wan27imagepro";
  const hasReferences = args.referenceImages.length > 0;
  const maxDimension = isPro && !hasReferences ? 4096 : 2048;

  if (!args.size) {
    const normalQuality = getQualityPreset(args) === "normal";
    if (args.aspectRatio) {
      return resolveWanSizeFromAspectRatio(args.aspectRatio, normalQuality ? 1024 : 2048);
    }
    return normalQuality ? "1K" : "2K";
  }

  const preset = args.size.trim().toUpperCase();
  if (preset === "1K" || preset === "2K" || preset === "4K") {
    if (preset === "4K") {
      if (!isPro) {
        throw new Error("Replicate Wan 2.7 Image only supports 1K, 2K, or custom sizes up to 2048px.");
      }
      if (hasReferences) {
        throw new Error("Replicate Wan 2.7 Image Pro only supports 4K text-to-image. Remove --ref or lower the size.");
      }
    }
    return preset;
  }

  const parsed = parsePixelSize(args.size);
  if (!parsed) {
    throw new Error("Replicate Wan --size must be 1K, 2K, 4K, or an explicit WxH size.");
  }
  if (Math.max(parsed.width, parsed.height) > maxDimension) {
    const label = isPro ? "Wan 2.7 Image Pro" : "Wan 2.7 Image";
    throw new Error(
      `Replicate ${label} custom --size must keep width and height at or below ${maxDimension}px in the current mode.`
    );
  }
  return formatCustomWanSize(parsed);
}
/**
 * Build the Replicate `input` payload for the nano-banana family.
 *
 * Always sends `resolution` and `output_format: "png"`. The requested aspect
 * ratio is used when there is one. Otherwise, when reference images are present,
 * `aspect_ratio` is set to `match_input_image`. References go under `image_input`.
 */
function buildNanoBananaInput(
  prompt: string,
  model: string,
  args: CliArgs,
  referenceImages: string[],
): Record<string, unknown> {
  const hasReferences = referenceImages.length > 0;
  // Resolve resolution before aspect ratio so the same error surfaces first as before.
  const resolution = getNanoBananaResolution(args);
  const requested = getRequestedAspectRatio(model, args);
  const aspectRatio = requested || (hasReferences ? "match_input_image" : null);
  return {
    prompt,
    resolution,
    output_format: "png",
    ...(aspectRatio ? { aspect_ratio: aspectRatio } : {}),
    ...(hasReferences ? { image_input: referenceImages } : {}),
  };
}
/**
 * Build the Replicate `input` payload for Seedream 4.5 and Seedream 5 Lite.
 *
 * `size` comes from the family-specific resolver. An explicit `--ar` is validated
 * and forwarded. Without one, only Seedream 4.5 with reference images gets
 * `aspect_ratio: "match_input_image"`.
 */
function buildSeedreamInput(
  family: "seedream45" | "seedream5lite",
  prompt: string,
  model: string,
  args: CliArgs,
  referenceImages: string[],
): Record<string, unknown> {
  const hasReferences = referenceImages.length > 0;
  const size = family === "seedream45" ? resolveSeedream45Size(args) : resolveSeedream5LiteSize(args);

  let aspectRatio: string | null = null;
  if (args.aspectRatio) {
    validateDocumentedAspectRatio(model, args.aspectRatio);
    aspectRatio = args.aspectRatio;
  } else if (hasReferences && family === "seedream45") {
    aspectRatio = "match_input_image";
  }

  return {
    prompt,
    size,
    ...(hasReferences ? { image_input: referenceImages } : {}),
    ...(aspectRatio !== null ? { aspect_ratio: aspectRatio } : {}),
  };
}
/** Build the Replicate `input` payload for the Wan 2.7 image models. */
function buildWanInput(
  family: "wan27image" | "wan27imagepro",
  prompt: string,
  args: CliArgs,
  referenceImages: string[],
): Record<string, unknown> {
  const size = resolveWanSize(family, args);
  if (referenceImages.length === 0) {
    return { prompt, size };
  }
  // Wan takes references under `images`, unlike the `image_input` key the other families use.
  return { prompt, size, images: referenceImages };
}
/**
 * Reject CLI options the selected Replicate model cannot honor, before any network call.
 *
 * Checks run in a fixed order: model id, `--n`, `--imageSize`, then the
 * family-specific limits. The first problem found is the one reported.
 * Family checks dry-run the same resolvers the input builders use, so invalid
 * sizes fail here rather than mid-request.
 *
 * @throws Error describing the first unsupported option.
 */
export function validateArgs(model: string, args: CliArgs): void {
  // Throws on a malformed model id.
  parseModelId(model);

  if (args.n !== 1) {
    throw new Error("Replicate integration currently supports exactly one output image per request. Remove --n or use --n 1.");
  }
  // An imageSize whose source is "config" is tolerated; any other source is rejected.
  if (args.imageSize && args.imageSizeSource !== "config") {
    throw new Error("Replicate models in baoyu-imagine do not use --imageSize. Use --quality, --ar, or --size instead.");
  }

  const family = getModelFamily(model);
  const referenceCount = args.referenceImages.length;

  switch (family) {
    case "nano-banana": {
      if (referenceCount > 14) {
        throw new Error("Replicate nano-banana supports at most 14 reference images.");
      }
      if (args.aspectRatio) {
        validateDocumentedAspectRatio(model, args.aspectRatio);
      }
      if (args.size) {
        getRequestedAspectRatio(model, args);
        getNanoBananaResolution(args);
      }
      return;
    }
    case "seedream45": {
      if (referenceCount > 14) {
        throw new Error("Replicate Seedream 4.5 supports at most 14 reference images.");
      }
      if (args.aspectRatio) {
        validateDocumentedAspectRatio(model, args.aspectRatio);
      }
      resolveSeedream45Size(args);
      return;
    }
    case "seedream5lite": {
      if (referenceCount > 14) {
        throw new Error("Replicate Seedream 5 Lite supports at most 14 reference images.");
      }
      if (args.aspectRatio) {
        validateDocumentedAspectRatio(model, args.aspectRatio);
      }
      resolveSeedream5LiteSize(args);
      return;
    }
    case "wan27image":
    case "wan27imagepro": {
      if (referenceCount > 9) {
        throw new Error("Replicate Wan 2.7 image models support at most 9 reference images.");
      }
      if (args.aspectRatio && !parseAspectRatio(args.aspectRatio)) {
        throw new Error(`Replicate Wan aspect ratio must be in W:H format, got ${args.aspectRatio}.`);
      }
      resolveWanSize(family, args);
      return;
    }
    default: {
      // Unknown models only get a bare prompt, so any extra input option is an error.
      if (referenceCount > 0 || args.aspectRatio || args.size) {
        throw new Error(
          `Replicate model ${model} is not in the baoyu-imagine compatibility list. Supported families: google/nano-banana*, bytedance/seedream-4.5, bytedance/seedream-5-lite, wan-video/wan-2.7-image, wan-video/wan-2.7-image-pro.`
        );
      }
      return;
    }
  }
}
/**
 * File extension to use when saving images from this provider.
 *
 * Always `.png`. The `model` argument does not affect the result. It is kept so
 * the exported signature stays unchanged (presumably for parity with other
 * providers — confirm against callers).
 */
export function getDefaultOutputExtension(model: string): ".png" {
  // Reference the parameter so unused-parameter checks stay quiet without renaming it.
  void model;
  return ".png";
}
/**
 * Build the Replicate `input` payload for `model`, dispatching on its family.
 *
 * Unknown models receive a prompt-only payload. `validateArgs` rejects
 * references, `--ar`, and `--size` for them.
 */
export function buildInput(
  model: string,
  prompt: string,
  args: CliArgs,
  referenceImages: string[],
): Record<string, unknown> {
  const family = getModelFamily(model);
  switch (family) {
    case "nano-banana":
      return buildNanoBananaInput(prompt, model, args, referenceImages);
    case "seedream45":
    case "seedream5lite":
      return buildSeedreamInput(family, prompt, model, args, referenceImages);
    case "wan27image":
    case "wan27imagepro":
      return buildWanInput(family, prompt, args, referenceImages);
    default:
      return { prompt };
  }
}
/**
 * Read a local image and inline it as a base64 `data:` URL.
 *
 * The MIME type is chosen from the file extension, case-insensitively. Any
 * extension other than .jpg/.jpeg/.gif/.webp is labeled `image/png`.
 * NOTE(review): contents are not sniffed, so e.g. a .bmp file would be mislabeled.
 */
async function readImageAsDataUrl(p: string): Promise<string> {
  const mimeByExtension: Record<string, string> = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
  };
  const bytes = await readFile(p);
  const mimeType = mimeByExtension[path.extname(p).toLowerCase()] ?? "image/png";
  return `data:${mimeType};base64,${bytes.toString("base64")}`;
}
/** The fields of a Replicate prediction object that this provider reads. */
interface PredictionResponse {
  id: string;
  /** Lifecycle state; this module checks for "succeeded", "failed", and "canceled". */
  status: string;
  /** Model output; shape varies by model (see extractOutputUrl). */
  output: unknown;
  error: string | null;
  /** `get` is the URL polled for status updates. */
  urls?: { get?: string };
}
/**
 * Create a Replicate prediction.
 *
 * A pinned version posts to `/v1/predictions` with `version` in the body.
 * Otherwise the model-scoped `/v1/models/{owner}/{name}/predictions` endpoint
 * is used. With `sync`, a `Prefer: wait=N` header asks the API to hold the
 * response for up to SYNC_WAIT_SECONDS.
 *
 * @throws Error with the HTTP status and response body on a non-2xx reply.
 */
async function createPrediction(
  apiToken: string,
  model: { owner: string; name: string; version: string | null },
  input: Record<string, unknown>,
  sync: boolean
): Promise<PredictionResponse> {
  const baseUrl = getBaseUrl();
  const url = model.version
    ? `${baseUrl}/v1/predictions`
    : `${baseUrl}/v1/models/${model.owner}/${model.name}/predictions`;
  const body: Record<string, unknown> = model.version ? { input, version: model.version } : { input };
  const headers: Record<string, string> = {
    Authorization: `Bearer ${apiToken}`,
    "Content-Type": "application/json",
    ...(sync ? { Prefer: `wait=${SYNC_WAIT_SECONDS}` } : {}),
  };

  const response = await fetch(url, {
    method: "POST",
    headers,
    body: JSON.stringify(body),
  });
  if (!response.ok) {
    const detail = await response.text();
    throw new Error(`Replicate API error (${response.status}): ${detail}`);
  }
  return (await response.json()) as PredictionResponse;
}
/**
 * Poll a prediction's status URL every POLL_INTERVAL_MS until it reaches a terminal state.
 *
 * @returns The prediction once its status is "succeeded".
 * @throws Error on a non-2xx poll response, on "failed"/"canceled", or after MAX_POLL_MS elapses.
 */
async function pollPrediction(apiToken: string, getUrl: string): Promise<PredictionResponse> {
  const deadline = Date.now() + MAX_POLL_MS;
  const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));

  while (Date.now() < deadline) {
    const response = await fetch(getUrl, {
      headers: { Authorization: `Bearer ${apiToken}` },
    });
    if (!response.ok) {
      const detail = await response.text();
      throw new Error(`Replicate poll error (${response.status}): ${detail}`);
    }
    const prediction = (await response.json()) as PredictionResponse;
    switch (prediction.status) {
      case "succeeded":
        return prediction;
      case "failed":
      case "canceled":
        throw new Error(`Replicate prediction ${prediction.status}: ${prediction.error || "unknown error"}`);
    }
    await sleep(POLL_INTERVAL_MS);
  }
  throw new Error(`Replicate prediction timed out after ${MAX_POLL_MS / 1000}s`);
}
/**
 * Pull the single image URL out of a prediction's `output`.
 *
 * Accepted shapes: a bare string, a one-element array holding a string, or an
 * object with a string `url` property.
 *
 * @throws Error when an array does not hold exactly one entry, or for any other shape.
 */
export function extractOutputUrl(prediction: PredictionResponse): string {
  const { output } = prediction;
  if (typeof output === "string") {
    return output;
  }
  if (Array.isArray(output)) {
    if (output.length !== 1) {
      throw new Error(
        `Replicate returned ${output.length} outputs, but baoyu-imagine currently supports saving exactly one image per request.`
      );
    }
    const [only] = output;
    if (typeof only === "string") {
      return only;
    }
  }
  if (typeof output === "object" && output !== null && "url" in output) {
    const candidate = (output as Record<string, unknown>).url;
    if (typeof candidate === "string") {
      return candidate;
    }
  }
  throw new Error(`Unexpected Replicate output format: ${JSON.stringify(output)}`);
}
/**
 * Fetch the generated image bytes from `url`.
 *
 * @throws Error carrying the HTTP status on a non-2xx response.
 */
async function downloadImage(url: string): Promise<Uint8Array> {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to download image from Replicate: ${response.status}`);
  }
  return new Uint8Array(await response.arrayBuffer());
}
/**
 * Generate a single image with Replicate and return the downloaded bytes.
 *
 * Flow:
 * 1. Check the API token, model id, and CLI options.
 * 2. Inline reference images as data URLs.
 * 3. Create the prediction with a synchronous wait.
 * 4. Poll if the prediction has not finished.
 * 5. Download the single output URL.
 *
 * @param prompt - Text prompt sent as `input.prompt`.
 * @param model - Replicate model reference (`owner/name` or `owner/name:version`).
 * @param args - Parsed CLI options.
 * @returns Raw bytes of the generated image.
 * @throws Error when the token is missing, options are unsupported, the API
 *   returns an error, the prediction fails, is canceled, or times out, or the
 *   download fails.
 */
export async function generateImage(
  prompt: string,
  model: string,
  args: CliArgs
): Promise<Uint8Array> {
  const apiToken = getApiToken();
  if (!apiToken) throw new Error("REPLICATE_API_TOKEN is required. Get one at https://replicate.com/account/api-tokens");

  const parsedModel = parseModelId(model);
  validateArgs(model, args);

  // The reference files are independent, so read them concurrently.
  // Promise.all keeps results in input order.
  const refDataUrls = await Promise.all(
    args.referenceImages.map((refPath) => readImageAsDataUrl(refPath))
  );
  const input = buildInput(model, prompt, args, refDataUrls);

  console.log(`Generating image with Replicate (${model})...`);
  let prediction = await createPrediction(apiToken, parsedModel, input, true);

  // The synchronous wait can already return a terminal failure. Report it directly
  // instead of polling again, or failing with a misleading "no poll URL" error.
  if (prediction.status === "failed" || prediction.status === "canceled") {
    throw new Error(`Replicate prediction ${prediction.status}: ${prediction.error || "unknown error"}`);
  }
  if (prediction.status !== "succeeded") {
    if (!prediction.urls?.get) {
      throw new Error("Replicate prediction did not return a poll URL");
    }
    console.log("Waiting for prediction to complete...");
    prediction = await pollPrediction(apiToken, prediction.urls.get);
  }

  console.log("Generation completed.");
  const outputUrl = extractOutputUrl(prediction);
  return downloadImage(outputUrl);
}