605 lines
17 KiB
TypeScript
605 lines
17 KiB
TypeScript
import path from "node:path";
|
|
import { readFile } from "node:fs/promises";
|
|
import type { CliArgs } from "../types";
|
|
|
|
const DEFAULT_MODEL = "google/nano-banana-2";
|
|
const SYNC_WAIT_SECONDS = 60;
|
|
const POLL_INTERVAL_MS = 2000;
|
|
const MAX_POLL_MS = 300_000;
|
|
const DOCUMENTED_REPLICATE_ASPECT_RATIOS = new Set([
|
|
"1:1",
|
|
"2:3",
|
|
"3:2",
|
|
"3:4",
|
|
"4:3",
|
|
"5:4",
|
|
"4:5",
|
|
"9:16",
|
|
"16:9",
|
|
"21:9",
|
|
]);
|
|
|
|
export type ReplicateModelFamily =
|
|
| "nano-banana"
|
|
| "seedream45"
|
|
| "seedream5lite"
|
|
| "wan27image"
|
|
| "wan27imagepro"
|
|
| "unknown";
|
|
|
|
type PixelSize = {
|
|
width: number;
|
|
height: number;
|
|
};
|
|
|
|
export function getDefaultModel(): string {
|
|
return process.env.REPLICATE_IMAGE_MODEL || DEFAULT_MODEL;
|
|
}
|
|
|
|
function getApiToken(): string | null {
|
|
return process.env.REPLICATE_API_TOKEN || null;
|
|
}
|
|
|
|
function getBaseUrl(): string {
|
|
const base = process.env.REPLICATE_BASE_URL || "https://api.replicate.com";
|
|
return base.replace(/\/+$/g, "");
|
|
}
|
|
|
|
function normalizeModelId(model: string): string {
|
|
return model.trim().toLowerCase().split(":")[0]!;
|
|
}
|
|
|
|
export function getModelFamily(model: string): ReplicateModelFamily {
|
|
const normalized = normalizeModelId(model);
|
|
|
|
if (
|
|
normalized === "google/nano-banana" ||
|
|
normalized === "google/nano-banana-pro" ||
|
|
normalized === "google/nano-banana-2"
|
|
) {
|
|
return "nano-banana";
|
|
}
|
|
|
|
if (normalized === "bytedance/seedream-4.5") {
|
|
return "seedream45";
|
|
}
|
|
|
|
if (normalized === "bytedance/seedream-5-lite") {
|
|
return "seedream5lite";
|
|
}
|
|
|
|
if (normalized === "wan-video/wan-2.7-image") {
|
|
return "wan27image";
|
|
}
|
|
|
|
if (normalized === "wan-video/wan-2.7-image-pro") {
|
|
return "wan27imagepro";
|
|
}
|
|
|
|
return "unknown";
|
|
}
|
|
|
|
export function parseModelId(model: string): { owner: string; name: string; version: string | null } {
|
|
const [ownerName, version] = model.split(":");
|
|
const parts = ownerName!.split("/");
|
|
if (parts.length !== 2 || !parts[0] || !parts[1]) {
|
|
throw new Error(
|
|
`Invalid Replicate model format: "${model}". Expected "owner/name" or "owner/name:version".`
|
|
);
|
|
}
|
|
return { owner: parts[0], name: parts[1], version: version || null };
|
|
}
|
|
|
|
function parsePixelSize(value: string): PixelSize | null {
|
|
const match = value.trim().match(/^(\d+)\s*[xX*]\s*(\d+)$/);
|
|
if (!match) return null;
|
|
|
|
const width = parseInt(match[1]!, 10);
|
|
const height = parseInt(match[2]!, 10);
|
|
if (!Number.isFinite(width) || !Number.isFinite(height) || width <= 0 || height <= 0) {
|
|
return null;
|
|
}
|
|
|
|
return { width, height };
|
|
}
|
|
|
|
function parseAspectRatio(value: string): PixelSize | null {
|
|
const match = value.trim().match(/^(\d+)\s*:\s*(\d+)$/);
|
|
if (!match) return null;
|
|
|
|
const width = parseInt(match[1]!, 10);
|
|
const height = parseInt(match[2]!, 10);
|
|
if (!Number.isFinite(width) || !Number.isFinite(height) || width <= 0 || height <= 0) {
|
|
return null;
|
|
}
|
|
|
|
return { width, height };
|
|
}
|
|
|
|
function gcd(a: number, b: number): number {
|
|
let x = Math.abs(a);
|
|
let y = Math.abs(b);
|
|
|
|
while (y !== 0) {
|
|
const next = x % y;
|
|
x = y;
|
|
y = next;
|
|
}
|
|
|
|
return x || 1;
|
|
}
|
|
|
|
function inferAspectRatioFromSize(size: string): string | null {
|
|
const parsed = parsePixelSize(size);
|
|
if (!parsed) return null;
|
|
|
|
const divisor = gcd(parsed.width, parsed.height);
|
|
const normalized = `${parsed.width / divisor}:${parsed.height / divisor}`;
|
|
if (!DOCUMENTED_REPLICATE_ASPECT_RATIOS.has(normalized)) {
|
|
return null;
|
|
}
|
|
|
|
return normalized;
|
|
}
|
|
|
|
function getQualityPreset(args: CliArgs): "normal" | "2k" {
|
|
return args.quality === "normal" ? "normal" : "2k";
|
|
}
|
|
|
|
function validateDocumentedAspectRatio(model: string, aspectRatio: string): void {
|
|
if (aspectRatio === "match_input_image") {
|
|
return;
|
|
}
|
|
|
|
if (DOCUMENTED_REPLICATE_ASPECT_RATIOS.has(aspectRatio)) {
|
|
return;
|
|
}
|
|
|
|
throw new Error(
|
|
`Replicate model ${model} does not support aspect ratio ${aspectRatio}. Supported values: ${Array.from(DOCUMENTED_REPLICATE_ASPECT_RATIOS).join(", ")}`
|
|
);
|
|
}
|
|
|
|
function getRequestedAspectRatio(model: string, args: CliArgs): string | null {
|
|
if (args.aspectRatio) {
|
|
validateDocumentedAspectRatio(model, args.aspectRatio);
|
|
return args.aspectRatio;
|
|
}
|
|
|
|
if (!args.size) return null;
|
|
|
|
const inferred = inferAspectRatioFromSize(args.size);
|
|
if (!inferred) {
|
|
throw new Error(
|
|
`Replicate model ${model} cannot derive a supported aspect ratio from --size ${args.size}. Use one of: ${Array.from(DOCUMENTED_REPLICATE_ASPECT_RATIOS).join(", ")}`
|
|
);
|
|
}
|
|
|
|
return inferred;
|
|
}
|
|
|
|
function getNanoBananaResolution(args: CliArgs): "1K" | "2K" {
|
|
if (args.size) {
|
|
const parsed = parsePixelSize(args.size);
|
|
if (!parsed) {
|
|
throw new Error("Replicate nano-banana --size must be in WxH format, for example 1536x1024.");
|
|
}
|
|
|
|
const longestEdge = Math.max(parsed.width, parsed.height);
|
|
if (longestEdge <= 1024) return "1K";
|
|
if (longestEdge <= 2048) return "2K";
|
|
throw new Error("Replicate nano-banana only supports sizes that map to 1K or 2K output.");
|
|
}
|
|
|
|
return getQualityPreset(args) === "normal" ? "1K" : "2K";
|
|
}
|
|
|
|
function resolveSeedream45Size(args: CliArgs): "2K" | "4K" | string {
|
|
if (args.size) {
|
|
const upper = args.size.trim().toUpperCase();
|
|
if (upper === "2K" || upper === "4K") {
|
|
return upper;
|
|
}
|
|
|
|
const parsed = parsePixelSize(args.size);
|
|
if (!parsed) {
|
|
throw new Error("Replicate Seedream 4.5 --size must be 2K, 4K, or an explicit WxH size.");
|
|
}
|
|
if (parsed.width < 1024 || parsed.width > 4096 || parsed.height < 1024 || parsed.height > 4096) {
|
|
throw new Error("Replicate Seedream 4.5 custom --size must keep width and height between 1024 and 4096.");
|
|
}
|
|
return `${parsed.width}x${parsed.height}`;
|
|
}
|
|
|
|
return getQualityPreset(args) === "normal" ? "2K" : "4K";
|
|
}
|
|
|
|
function resolveSeedream5LiteSize(args: CliArgs): "2K" | "3K" {
|
|
if (args.size) {
|
|
const upper = args.size.trim().toUpperCase();
|
|
if (upper === "2K" || upper === "3K") {
|
|
return upper;
|
|
}
|
|
|
|
throw new Error("Replicate Seedream 5 Lite currently supports 2K or 3K output in this tool.");
|
|
}
|
|
|
|
return getQualityPreset(args) === "normal" ? "2K" : "3K";
|
|
}
|
|
|
|
function formatCustomWanSize(size: PixelSize): string {
|
|
return `${size.width}*${size.height}`;
|
|
}
|
|
|
|
function resolveWanSizeFromAspectRatio(
|
|
aspectRatio: string,
|
|
maxDimension: number,
|
|
): string {
|
|
const parsedRatio = parseAspectRatio(aspectRatio);
|
|
if (!parsedRatio) {
|
|
throw new Error(`Replicate Wan aspect ratio must be in W:H format, got ${aspectRatio}.`);
|
|
}
|
|
|
|
const scale = Math.min(maxDimension / parsedRatio.width, maxDimension / parsedRatio.height);
|
|
const width = Math.max(1, Math.floor(parsedRatio.width * scale));
|
|
const height = Math.max(1, Math.floor(parsedRatio.height * scale));
|
|
return formatCustomWanSize({ width, height });
|
|
}
|
|
|
|
function resolveWanSize(family: "wan27image" | "wan27imagepro", args: CliArgs): "1K" | "2K" | "4K" | string {
|
|
const referenceMode = args.referenceImages.length > 0;
|
|
const maxDimension = family === "wan27imagepro" && !referenceMode ? 4096 : 2048;
|
|
|
|
if (args.size) {
|
|
const upper = args.size.trim().toUpperCase();
|
|
if (upper === "1K" || upper === "2K" || upper === "4K") {
|
|
if (upper === "4K" && family !== "wan27imagepro") {
|
|
throw new Error("Replicate Wan 2.7 Image only supports 1K, 2K, or custom sizes up to 2048px.");
|
|
}
|
|
if (upper === "4K" && referenceMode) {
|
|
throw new Error("Replicate Wan 2.7 Image Pro only supports 4K text-to-image. Remove --ref or lower the size.");
|
|
}
|
|
return upper;
|
|
}
|
|
|
|
const parsed = parsePixelSize(args.size);
|
|
if (!parsed) {
|
|
throw new Error("Replicate Wan --size must be 1K, 2K, 4K, or an explicit WxH size.");
|
|
}
|
|
if (parsed.width > maxDimension || parsed.height > maxDimension) {
|
|
throw new Error(
|
|
`Replicate ${family === "wan27imagepro" ? "Wan 2.7 Image Pro" : "Wan 2.7 Image"} custom --size must keep width and height at or below ${maxDimension}px in the current mode.`
|
|
);
|
|
}
|
|
return formatCustomWanSize(parsed);
|
|
}
|
|
|
|
if (args.aspectRatio) {
|
|
return resolveWanSizeFromAspectRatio(
|
|
args.aspectRatio,
|
|
getQualityPreset(args) === "normal" ? 1024 : 2048,
|
|
);
|
|
}
|
|
|
|
return getQualityPreset(args) === "normal" ? "1K" : "2K";
|
|
}
|
|
|
|
function buildNanoBananaInput(
|
|
prompt: string,
|
|
model: string,
|
|
args: CliArgs,
|
|
referenceImages: string[],
|
|
): Record<string, unknown> {
|
|
const input: Record<string, unknown> = {
|
|
prompt,
|
|
resolution: getNanoBananaResolution(args),
|
|
output_format: "png",
|
|
};
|
|
|
|
const aspectRatio = getRequestedAspectRatio(model, args);
|
|
if (aspectRatio) {
|
|
input.aspect_ratio = aspectRatio;
|
|
} else if (referenceImages.length > 0) {
|
|
input.aspect_ratio = "match_input_image";
|
|
}
|
|
|
|
if (referenceImages.length > 0) {
|
|
input.image_input = referenceImages;
|
|
}
|
|
|
|
return input;
|
|
}
|
|
|
|
function buildSeedreamInput(
|
|
family: "seedream45" | "seedream5lite",
|
|
prompt: string,
|
|
model: string,
|
|
args: CliArgs,
|
|
referenceImages: string[],
|
|
): Record<string, unknown> {
|
|
const input: Record<string, unknown> = {
|
|
prompt,
|
|
size: family === "seedream45" ? resolveSeedream45Size(args) : resolveSeedream5LiteSize(args),
|
|
};
|
|
|
|
if (referenceImages.length > 0) {
|
|
input.image_input = referenceImages;
|
|
}
|
|
|
|
if (args.aspectRatio) {
|
|
validateDocumentedAspectRatio(model, args.aspectRatio);
|
|
input.aspect_ratio = args.aspectRatio;
|
|
} else if (referenceImages.length > 0 && family === "seedream45") {
|
|
input.aspect_ratio = "match_input_image";
|
|
}
|
|
|
|
return input;
|
|
}
|
|
|
|
function buildWanInput(
|
|
family: "wan27image" | "wan27imagepro",
|
|
prompt: string,
|
|
args: CliArgs,
|
|
referenceImages: string[],
|
|
): Record<string, unknown> {
|
|
const input: Record<string, unknown> = {
|
|
prompt,
|
|
size: resolveWanSize(family, args),
|
|
};
|
|
|
|
if (referenceImages.length > 0) {
|
|
input.images = referenceImages;
|
|
}
|
|
|
|
return input;
|
|
}
|
|
|
|
export function validateArgs(model: string, args: CliArgs): void {
|
|
parseModelId(model);
|
|
|
|
if (args.n !== 1) {
|
|
throw new Error("Replicate integration currently supports exactly one output image per request. Remove --n or use --n 1.");
|
|
}
|
|
|
|
if (args.imageSize && args.imageSizeSource !== "config") {
|
|
throw new Error("Replicate models in baoyu-imagine do not use --imageSize. Use --quality, --ar, or --size instead.");
|
|
}
|
|
|
|
const family = getModelFamily(model);
|
|
|
|
if (family === "nano-banana") {
|
|
if (args.referenceImages.length > 14) {
|
|
throw new Error("Replicate nano-banana supports at most 14 reference images.");
|
|
}
|
|
if (args.aspectRatio) {
|
|
validateDocumentedAspectRatio(model, args.aspectRatio);
|
|
}
|
|
if (args.size) {
|
|
getRequestedAspectRatio(model, args);
|
|
getNanoBananaResolution(args);
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (family === "seedream45") {
|
|
if (args.referenceImages.length > 14) {
|
|
throw new Error("Replicate Seedream 4.5 supports at most 14 reference images.");
|
|
}
|
|
if (args.aspectRatio) {
|
|
validateDocumentedAspectRatio(model, args.aspectRatio);
|
|
}
|
|
resolveSeedream45Size(args);
|
|
return;
|
|
}
|
|
|
|
if (family === "seedream5lite") {
|
|
if (args.referenceImages.length > 14) {
|
|
throw new Error("Replicate Seedream 5 Lite supports at most 14 reference images.");
|
|
}
|
|
if (args.aspectRatio) {
|
|
validateDocumentedAspectRatio(model, args.aspectRatio);
|
|
}
|
|
resolveSeedream5LiteSize(args);
|
|
return;
|
|
}
|
|
|
|
if (family === "wan27image" || family === "wan27imagepro") {
|
|
if (args.referenceImages.length > 9) {
|
|
throw new Error("Replicate Wan 2.7 image models support at most 9 reference images.");
|
|
}
|
|
if (args.aspectRatio) {
|
|
const parsed = parseAspectRatio(args.aspectRatio);
|
|
if (!parsed) {
|
|
throw new Error(`Replicate Wan aspect ratio must be in W:H format, got ${args.aspectRatio}.`);
|
|
}
|
|
}
|
|
resolveWanSize(family, args);
|
|
return;
|
|
}
|
|
|
|
if (args.referenceImages.length > 0 || args.aspectRatio || args.size) {
|
|
throw new Error(
|
|
`Replicate model ${model} is not in the baoyu-imagine compatibility list. Supported families: google/nano-banana*, bytedance/seedream-4.5, bytedance/seedream-5-lite, wan-video/wan-2.7-image, wan-video/wan-2.7-image-pro.`
|
|
);
|
|
}
|
|
}
|
|
|
|
export function getDefaultOutputExtension(model: string): ".png" {
|
|
const _family = getModelFamily(model);
|
|
return ".png";
|
|
}
|
|
|
|
export function buildInput(
|
|
model: string,
|
|
prompt: string,
|
|
args: CliArgs,
|
|
referenceImages: string[],
|
|
): Record<string, unknown> {
|
|
const family = getModelFamily(model);
|
|
|
|
if (family === "nano-banana") {
|
|
return buildNanoBananaInput(prompt, model, args, referenceImages);
|
|
}
|
|
|
|
if (family === "seedream45" || family === "seedream5lite") {
|
|
return buildSeedreamInput(family, prompt, model, args, referenceImages);
|
|
}
|
|
|
|
if (family === "wan27image" || family === "wan27imagepro") {
|
|
return buildWanInput(family, prompt, args, referenceImages);
|
|
}
|
|
|
|
return { prompt };
|
|
}
|
|
|
|
async function readImageAsDataUrl(p: string): Promise<string> {
|
|
const buf = await readFile(p);
|
|
const ext = path.extname(p).toLowerCase();
|
|
let mimeType = "image/png";
|
|
if (ext === ".jpg" || ext === ".jpeg") mimeType = "image/jpeg";
|
|
else if (ext === ".gif") mimeType = "image/gif";
|
|
else if (ext === ".webp") mimeType = "image/webp";
|
|
return `data:${mimeType};base64,${buf.toString("base64")}`;
|
|
}
|
|
|
|
type PredictionResponse = {
|
|
id: string;
|
|
status: string;
|
|
output: unknown;
|
|
error: string | null;
|
|
urls?: { get?: string };
|
|
};
|
|
|
|
async function createPrediction(
|
|
apiToken: string,
|
|
model: { owner: string; name: string; version: string | null },
|
|
input: Record<string, unknown>,
|
|
sync: boolean
|
|
): Promise<PredictionResponse> {
|
|
const baseUrl = getBaseUrl();
|
|
|
|
let url: string;
|
|
const body: Record<string, unknown> = { input };
|
|
|
|
if (model.version) {
|
|
url = `${baseUrl}/v1/predictions`;
|
|
body.version = model.version;
|
|
} else {
|
|
url = `${baseUrl}/v1/models/${model.owner}/${model.name}/predictions`;
|
|
}
|
|
|
|
const headers: Record<string, string> = {
|
|
Authorization: `Bearer ${apiToken}`,
|
|
"Content-Type": "application/json",
|
|
};
|
|
|
|
if (sync) {
|
|
headers["Prefer"] = `wait=${SYNC_WAIT_SECONDS}`;
|
|
}
|
|
|
|
const res = await fetch(url, {
|
|
method: "POST",
|
|
headers,
|
|
body: JSON.stringify(body),
|
|
});
|
|
|
|
if (!res.ok) {
|
|
const err = await res.text();
|
|
throw new Error(`Replicate API error (${res.status}): ${err}`);
|
|
}
|
|
|
|
return (await res.json()) as PredictionResponse;
|
|
}
|
|
|
|
async function pollPrediction(apiToken: string, getUrl: string): Promise<PredictionResponse> {
|
|
const start = Date.now();
|
|
|
|
while (Date.now() - start < MAX_POLL_MS) {
|
|
const res = await fetch(getUrl, {
|
|
headers: { Authorization: `Bearer ${apiToken}` },
|
|
});
|
|
|
|
if (!res.ok) {
|
|
const err = await res.text();
|
|
throw new Error(`Replicate poll error (${res.status}): ${err}`);
|
|
}
|
|
|
|
const prediction = (await res.json()) as PredictionResponse;
|
|
|
|
if (prediction.status === "succeeded") return prediction;
|
|
if (prediction.status === "failed" || prediction.status === "canceled") {
|
|
throw new Error(`Replicate prediction ${prediction.status}: ${prediction.error || "unknown error"}`);
|
|
}
|
|
|
|
await new Promise((r) => setTimeout(r, POLL_INTERVAL_MS));
|
|
}
|
|
|
|
throw new Error(`Replicate prediction timed out after ${MAX_POLL_MS / 1000}s`);
|
|
}
|
|
|
|
export function extractOutputUrl(prediction: PredictionResponse): string {
|
|
const output = prediction.output;
|
|
|
|
if (typeof output === "string") return output;
|
|
|
|
if (Array.isArray(output)) {
|
|
if (output.length !== 1) {
|
|
throw new Error(
|
|
`Replicate returned ${output.length} outputs, but baoyu-imagine currently supports saving exactly one image per request.`
|
|
);
|
|
}
|
|
const first = output[0];
|
|
if (typeof first === "string") return first;
|
|
}
|
|
|
|
if (output && typeof output === "object" && "url" in output) {
|
|
const url = (output as Record<string, unknown>).url;
|
|
if (typeof url === "string") return url;
|
|
}
|
|
|
|
throw new Error(`Unexpected Replicate output format: ${JSON.stringify(output)}`);
|
|
}
|
|
|
|
async function downloadImage(url: string): Promise<Uint8Array> {
|
|
const res = await fetch(url);
|
|
if (!res.ok) throw new Error(`Failed to download image from Replicate: ${res.status}`);
|
|
const buf = await res.arrayBuffer();
|
|
return new Uint8Array(buf);
|
|
}
|
|
|
|
export async function generateImage(
|
|
prompt: string,
|
|
model: string,
|
|
args: CliArgs
|
|
): Promise<Uint8Array> {
|
|
const apiToken = getApiToken();
|
|
if (!apiToken) throw new Error("REPLICATE_API_TOKEN is required. Get one at https://replicate.com/account/api-tokens");
|
|
|
|
const parsedModel = parseModelId(model);
|
|
validateArgs(model, args);
|
|
|
|
const refDataUrls: string[] = [];
|
|
for (const refPath of args.referenceImages) {
|
|
refDataUrls.push(await readImageAsDataUrl(refPath));
|
|
}
|
|
|
|
const input = buildInput(model, prompt, args, refDataUrls);
|
|
|
|
console.log(`Generating image with Replicate (${model})...`);
|
|
|
|
let prediction = await createPrediction(apiToken, parsedModel, input, true);
|
|
|
|
if (prediction.status !== "succeeded") {
|
|
if (!prediction.urls?.get) {
|
|
throw new Error("Replicate prediction did not return a poll URL");
|
|
}
|
|
console.log("Waiting for prediction to complete...");
|
|
prediction = await pollPrediction(apiToken, prediction.urls.get);
|
|
}
|
|
|
|
console.log("Generation completed.");
|
|
|
|
const outputUrl = extractOutputUrl(prediction);
|
|
return downloadImage(outputUrl);
|
|
}
|