JimLiu-baoyu-skills/skills/baoyu-image-gen/scripts/providers/replicate.ts

206 lines
5.8 KiB
TypeScript

import path from "node:path";
import { readFile } from "node:fs/promises";
import type { CliArgs } from "../types";
// Model identifier used when the REPLICATE_IMAGE_MODEL environment variable is unset or empty.
const DEFAULT_MODEL = "google/nano-banana-pro";
// Value sent in the "Prefer: wait=<seconds>" header when a prediction is created synchronously.
const SYNC_WAIT_SECONDS = 60;
// Delay between two consecutive status polls, in milliseconds.
const POLL_INTERVAL_MS = 2000;
// Total time budget for polling a prediction before giving up, in milliseconds.
const MAX_POLL_MS = 300_000;
/**
 * Resolves the model identifier to use when the caller does not name one.
 *
 * @returns The value of `REPLICATE_IMAGE_MODEL` when it is set to a non-empty
 *   string, otherwise the built-in default model.
 */
export function getDefaultModel(): string {
  const override = process.env.REPLICATE_IMAGE_MODEL;
  if (override) {
    return override;
  }
  return DEFAULT_MODEL;
}
/**
 * Reads the Replicate API token from the environment.
 *
 * @returns The value of `REPLICATE_API_TOKEN`, or `null` when it is unset or empty.
 */
function getApiToken(): string | null {
  const token = process.env.REPLICATE_API_TOKEN;
  return token ? token : null;
}
/**
 * Resolves the API root URL, honoring the `REPLICATE_BASE_URL` override.
 *
 * @returns The configured (or default) base URL with any trailing slashes removed,
 *   so path segments can be appended with a single leading "/".
 */
function getBaseUrl(): string {
  const configured = process.env.REPLICATE_BASE_URL;
  const root = configured ? configured : "https://api.replicate.com";
  return root.replace(/\/+$/g, "");
}
/**
 * Splits a Replicate model reference into its components.
 *
 * Accepted forms are `owner/name` and `owner/name:version`. Only the text
 * between the first and second ":" is taken as the version; an empty version
 * is reported as `null`.
 *
 * @param model - Model reference string to parse.
 * @returns The owner, model name, and version (or `null` when absent).
 * @throws Error when the part before ":" is not exactly two non-empty
 *   "/"-separated segments.
 */
export function parseModelId(model: string): { owner: string; name: string; version: string | null } {
  // split() always yields at least one element; the default only satisfies the type checker.
  const [slug = "", version] = model.split(":");
  const segments = slug.split("/");
  const [owner, name] = segments;

  if (segments.length === 2 && owner && name) {
    return { owner, name, version: version ? version : null };
  }

  throw new Error(
    `Invalid Replicate model format: "${model}". Expected "owner/name" or "owner/name:version".`
  );
}
/**
 * Assembles the `input` object sent with a prediction request.
 *
 * @param prompt - Text prompt, always included.
 * @param args - CLI options; `aspectRatio`, `n`, and `quality` are read here.
 * @param referenceImages - Reference images to attach as `image_input`
 *   (passed through as-is, by reference).
 * @returns A new input record. `output_format` is always `"png"`; the other
 *   optional keys are only present when the corresponding option applies.
 */
export function buildInput(prompt: string, args: CliArgs, referenceImages: string[]): Record<string, unknown> {
  const hasReferences = referenceImages.length > 0;
  const input: Record<string, unknown> = { prompt };

  // An explicit aspect ratio wins; otherwise follow the reference image when one exists.
  const aspectRatio = args.aspectRatio || (hasReferences ? "match_input_image" : undefined);
  if (aspectRatio !== undefined) {
    input.aspect_ratio = aspectRatio;
  }

  if (args.n > 1) {
    input.number_of_images = args.n;
  }

  // Only the two known quality levels map to a resolution; anything else leaves it unset.
  switch (args.quality) {
    case "normal":
      input.resolution = "1K";
      break;
    case "2k":
      input.resolution = "2K";
      break;
    default:
      break;
  }

  input.output_format = "png";

  if (hasReferences) {
    input.image_input = referenceImages;
  }

  return input;
}
/**
 * Loads an image file and encodes it as a base64 `data:` URL.
 *
 * The MIME type is chosen from the (case-insensitive) file extension:
 * `.jpg`/`.jpeg`, `.gif`, and `.webp` are recognized; every other extension
 * is labelled `image/png`.
 *
 * @param p - Path of the image file to read.
 * @returns The file contents as `data:<mime>;base64,<payload>`.
 */
async function readImageAsDataUrl(p: string): Promise<string> {
  const mimeByExtension = new Map<string, string>([
    [".jpg", "image/jpeg"],
    [".jpeg", "image/jpeg"],
    [".gif", "image/gif"],
    [".webp", "image/webp"],
  ]);

  const contents = await readFile(p);
  const extension = path.extname(p).toLowerCase();
  const mimeType = mimeByExtension.get(extension) ?? "image/png";
  const payload = contents.toString("base64");

  return `data:${mimeType};base64,${payload}`;
}
/**
 * Shape of the prediction JSON that this module reads from create/poll responses.
 * Response bodies are cast to this type without runtime validation.
 */
type PredictionResponse = {
// Not read anywhere in this module; presumably the prediction identifier.
id: string;
// Lifecycle state; this module checks for "succeeded", "failed", and "canceled".
status: string;
// Model result; extractOutputUrl accepts a string, an array whose first item is a string, or an object with a string `url`.
output: unknown;
// Failure detail included in the thrown message when status is "failed" or "canceled".
error: string | null;
// `get` is the URL that is polled until the prediction finishes.
urls?: { get?: string };
};
/**
 * Submits a new prediction request.
 *
 * A model with a version is posted to the generic predictions endpoint with
 * `version` in the body; otherwise the model-scoped predictions endpoint is used.
 *
 * @param apiToken - Bearer token for the `Authorization` header.
 * @param model - Parsed model reference.
 * @param input - Model input payload.
 * @param sync - When true, adds a `Prefer: wait=<seconds>` header to the request.
 * @returns The parsed JSON response body.
 * @throws Error with the status code and response text when the response is not OK.
 */
async function createPrediction(
  apiToken: string,
  model: { owner: string; name: string; version: string | null },
  input: Record<string, unknown>,
  sync: boolean
): Promise<PredictionResponse> {
  const root = getBaseUrl();

  const endpoint = model.version
    ? `${root}/v1/predictions`
    : `${root}/v1/models/${model.owner}/${model.name}/predictions`;
  const payload: Record<string, unknown> = model.version
    ? { input, version: model.version }
    : { input };

  const headers: Record<string, string> = {
    Authorization: `Bearer ${apiToken}`,
    "Content-Type": "application/json",
    ...(sync ? { Prefer: `wait=${SYNC_WAIT_SECONDS}` } : {}),
  };

  const response = await fetch(endpoint, {
    method: "POST",
    headers,
    body: JSON.stringify(payload),
  });

  if (response.ok) {
    return (await response.json()) as PredictionResponse;
  }

  const detail = await response.text();
  throw new Error(`Replicate API error (${response.status}): ${detail}`);
}
/**
 * Repeatedly fetches a prediction until it reaches a terminal state.
 *
 * @param apiToken - Bearer token for the `Authorization` header.
 * @param getUrl - URL of the prediction to fetch.
 * @returns The prediction once its status is "succeeded".
 * @throws Error when a poll response is not OK, when the prediction ends as
 *   "failed" or "canceled", or when the polling time budget is exhausted.
 */
async function pollPrediction(apiToken: string, getUrl: string): Promise<PredictionResponse> {
  const deadline = Date.now() + MAX_POLL_MS;
  const requestInit = { headers: { Authorization: `Bearer ${apiToken}` } };

  while (Date.now() < deadline) {
    const response = await fetch(getUrl, requestInit);
    if (!response.ok) {
      const detail = await response.text();
      throw new Error(`Replicate poll error (${response.status}): ${detail}`);
    }

    const snapshot = (await response.json()) as PredictionResponse;
    switch (snapshot.status) {
      case "succeeded":
        return snapshot;
      case "failed":
      case "canceled":
        throw new Error(`Replicate prediction ${snapshot.status}: ${snapshot.error || "unknown error"}`);
      default:
        // Any other status: wait one interval, then poll again.
        await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS));
    }
  }

  throw new Error(`Replicate prediction timed out after ${MAX_POLL_MS / 1000}s`);
}
/**
 * Picks the image URL out of a prediction's `output` field.
 *
 * Supported shapes, checked in this order: a plain string, an array whose
 * first element is a string, and an object carrying a string `url` property.
 *
 * @param prediction - Prediction whose `output` should be inspected.
 * @returns The first string found according to the order above.
 * @throws Error when `output` matches none of the supported shapes.
 */
export function extractOutputUrl(prediction: PredictionResponse): string {
  const { output } = prediction;

  const firstItem: unknown = Array.isArray(output) ? output[0] : undefined;
  const urlField: unknown =
    typeof output === "object" && output !== null && "url" in output
      ? (output as Record<string, unknown>).url
      : undefined;

  // Candidates are listed in precedence order; the first string wins.
  for (const candidate of [output, firstItem, urlField]) {
    if (typeof candidate === "string") {
      return candidate;
    }
  }

  throw new Error(`Unexpected Replicate output format: ${JSON.stringify(output)}`);
}
/**
 * Fetches the bytes at the given URL.
 *
 * @param url - Location of the generated image.
 * @returns The response body as a byte array.
 * @throws Error carrying the HTTP status when the response is not OK.
 */
async function downloadImage(url: string): Promise<Uint8Array> {
  const response = await fetch(url);
  if (response.ok) {
    return new Uint8Array(await response.arrayBuffer());
  }
  throw new Error(`Failed to download image from Replicate: ${response.status}`);
}
/**
 * Generates an image with Replicate and returns its bytes.
 *
 * Flow: encode the reference images as data URLs, create the prediction with a
 * synchronous wait, fall back to polling if it has not finished by then, and
 * finally download the first output URL.
 *
 * @param prompt - Text prompt for the model.
 * @param model - Model reference in `owner/name` or `owner/name:version` form.
 * @param args - CLI options; `referenceImages` is read here and the rest is
 *   forwarded to `buildInput`.
 * @returns The downloaded image bytes.
 * @throws Error when `REPLICATE_API_TOKEN` is missing, the model reference is
 *   invalid, the prediction fails or is canceled, no poll URL is available for
 *   an unfinished prediction, or any HTTP request fails.
 */
export async function generateImage(
  prompt: string,
  model: string,
  args: CliArgs
): Promise<Uint8Array> {
  const apiToken = getApiToken();
  if (!apiToken) throw new Error("REPLICATE_API_TOKEN is required. Get one at https://replicate.com/account/api-tokens");

  const parsedModel = parseModelId(model);

  const refDataUrls: string[] = [];
  for (const refPath of args.referenceImages) {
    refDataUrls.push(await readImageAsDataUrl(refPath));
  }

  const input = buildInput(prompt, args, refDataUrls);

  console.log(`Generating image with Replicate (${model})...`);
  let prediction = await createPrediction(apiToken, parsedModel, input, true);

  // The synchronous wait can already end in a terminal failure. Report it here
  // instead of polling a finished prediction (or masking the real cause behind
  // the "no poll URL" error below).
  if (prediction.status === "failed" || prediction.status === "canceled") {
    throw new Error(`Replicate prediction ${prediction.status}: ${prediction.error || "unknown error"}`);
  }

  if (prediction.status !== "succeeded") {
    if (!prediction.urls?.get) {
      throw new Error("Replicate prediction did not return a poll URL");
    }
    console.log("Waiting for prediction to complete...");
    prediction = await pollPrediction(apiToken, prediction.urls.get);
  }

  console.log("Generation completed.");
  const outputUrl = extractOutputUrl(prediction);
  return downloadImage(outputUrl);
}