122 lines
4.0 KiB
Python
122 lines
4.0 KiB
Python
from pydantic import BaseModel, Field
|
|
from typing import List, Optional
|
|
from pydantic_ai import Agent, RunContext
|
|
import os
|
|
from openai import AsyncOpenAI
|
|
from pydantic_ai.providers import Provider
|
|
|
|
# --- Data Models ---
|
|
|
|
class Character(BaseModel):
    # One main character extracted from the novel text.
    # The Field descriptions are part of this model's JSON schema, so the
    # wording below is intentionally kept as-is.
    name: str = Field(
        ...,
        description="The name of the character.",
    )
    description: str = Field(
        ...,
        description="A concise visual description of the character (e.g., 'Blonde hair, blue dress, young').",
    )
|
|
|
|
class CharacterAnalysisResult(BaseModel):
    # Structured output type of the character-analysis agent: a list of
    # Character entries. The Field description is exported in the schema.
    characters: List[Character] = Field(
        ...,
        description="List of main characters identified in the text.",
    )
|
|
|
|
class MangaSimplePrompt(BaseModel):
    # Structured output type of the prompt-generation agent: a single
    # English prompt string. The Field description is exported in the schema.
    prompt: str = Field(
        ...,
        description="The generated English manga image prompt.",
    )
|
|
|
|
# --- Custom Provider ---
|
|
|
|
class OpenAIAuthProvider(Provider[AsyncOpenAI]):
    """
    Provider that wraps an AsyncOpenAI client built from a caller-supplied
    API key and optional base URL, so each request can target its own
    credentials/endpoint.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        # api_key and base_url are forwarded to AsyncOpenAI unchanged.
        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self._client = client

    @property
    def name(self) -> str:
        """Provider identifier reported to pydantic_ai."""
        return "openai"

    @property
    def base_url(self) -> str:
        """The wrapped client's base URL, rendered as a string."""
        return f"{self._client.base_url}"

    @property
    def client(self) -> AsyncOpenAI:
        """The AsyncOpenAI client created in __init__."""
        return self._client
|
|
|
|
# --- Wrapper Functions ---
|
|
|
|
async def analyze_characters_with_agent(
    text: str,
    api_key: str,
    base_url: Optional[str] = None,
    model: str = "gpt-4o",
) -> str:
    """
    Use a PydanticAI Agent to analyze characters and return a formatted string context.

    Args:
        text: Novel text to analyze.
        api_key: API key for the OpenAI-compatible endpoint; a falsy value is
            replaced by the literal "dummy".
        base_url: Optional endpoint override forwarded to the AsyncOpenAI client.
        model: Model name; a falsy value falls back to "gpt-4o".

    Returns:
        One newline-terminated "- {name}: {description}" line per character,
        or "" when ``text`` is blank or the agent run fails (the failure is
        logged with its traceback).
    """
    import logging

    # Nothing to analyze: skip the API call entirely.
    if not text or not text.strip():
        return ""

    # Prefer OpenAIChatModel; fall back to the older OpenAIModel name.
    try:
        from pydantic_ai.models.openai import OpenAIChatModel as OpenAIModel
    except ImportError:
        from pydantic_ai.models.openai import OpenAIModel

    provider = OpenAIAuthProvider(api_key=api_key or "dummy", base_url=base_url)
    agent = Agent(
        OpenAIModel(model or "gpt-4o", provider=provider),
        output_type=CharacterAnalysisResult,
        system_prompt="You are a professional manga editor. Analyze the provided novel text and extract the visual descriptions of the main characters to ensure consistency in manga adaptation.",
    )

    try:
        result = await agent.run(text)
    except Exception as exc:
        # Best-effort: an analysis failure degrades to "no character context".
        logging.getLogger(__name__).exception("Agent Error (Characters): %s", exc)
        return ""
    finally:
        # A fresh AsyncOpenAI client is created for every call; close it so
        # its underlying HTTP resources are released instead of leaking.
        await provider.client.close()

    return "".join(
        f"- {char.name}: {char.description}\n" for char in result.output.characters
    )
|
|
|
|
async def generate_single_prompt_with_agent(
    paragraph: str,
    character_context: str,
    api_key: str,
    base_url: Optional[str] = None,
    model: str = "gpt-4o",
) -> str:
    """
    Use a PydanticAI Agent to turn one paragraph of novel text into an English manga image prompt.

    Args:
        paragraph: Novel text to convert.
        character_context: Character descriptions appended to the system prompt
            to keep characters consistent; omitted when blank.
        api_key: API key for the OpenAI-compatible endpoint; a falsy value is
            replaced by the literal "dummy".
        base_url: Optional endpoint override forwarded to the AsyncOpenAI client.
        model: Model name; a falsy value falls back to "gpt-4o".

    Returns:
        The generated prompt, or "Error generating prompt: <exception>" when the
        agent run fails (the failure is also logged with its traceback).
    """
    import logging

    # Prefer OpenAIChatModel; fall back to the older OpenAIModel name.
    try:
        from pydantic_ai.models.openai import OpenAIChatModel as OpenAIModel
    except ImportError:
        from pydantic_ai.models.openai import OpenAIModel

    system_prompt = (
        "You are a professional manga artist assistant. Convert the novel text into a detailed manga image prompt. "
        "The prompt MUST be in English. Focus on visual details, character appearance, setting, and style (monochrome, manga style, high quality). "
    )
    # Only add the consistency instruction when there is context to follow;
    # otherwise the model would receive a dangling, empty instruction.
    if character_context and character_context.strip():
        system_prompt += f"Maintain character consistency:\n{character_context}"

    provider = OpenAIAuthProvider(api_key=api_key or "dummy", base_url=base_url)
    agent = Agent(
        OpenAIModel(model or "gpt-4o", provider=provider),
        output_type=MangaSimplePrompt,
        # NOTE: no tool or dynamic system prompt on this agent reads deps; the
        # context is already embedded in system_prompt above.
        deps_type=str,
        system_prompt=system_prompt,
    )

    try:
        result = await agent.run(paragraph, deps=character_context)
    except Exception as exc:
        logging.getLogger(__name__).exception("Agent Error (Prompt): %s", exc)
        return f"Error generating prompt: {exc}"
    finally:
        # A fresh AsyncOpenAI client is created for every call; close it so
        # its underlying HTTP resources are released instead of leaking.
        await provider.client.close()

    return result.output.prompt
|