Files
gh-k-dense-ai-claude-scient…/skills/scientific-schematics/scripts/generate_schematic_ai.py
2025-11-30 08:30:14 +08:00

673 lines
25 KiB
Python

#!/usr/bin/env python3
"""
AI-powered scientific schematic generation using Nano Banana Pro.
This script uses an iterative refinement approach:
1. Generate initial image with Nano Banana Pro
2. AI quality review for scientific critique
3. Improve prompt based on critique and regenerate
4. Repeat for the requested number of iterations (default: 3) to refine toward publication-quality results
Requirements:
- OPENROUTER_API_KEY environment variable
- requests library
Usage:
python generate_schematic_ai.py "Create a flowchart showing CONSORT participant flow" -o flowchart.png
python generate_schematic_ai.py "Neural network architecture diagram" -o architecture.png --iterations 3
"""
import argparse
import base64
import json
import os
import re
import shutil
import sys
import time
import traceback
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple

try:
    import requests
except ImportError:
    print("Error: requests library not found. Install with: pip install requests")
    sys.exit(1)
# Try to load .env file from multiple potential locations
def _load_env_file():
    """Load a .env file so OPENROUTER_API_KEY can live in a project file.

    Search order (first match wins):

    1. The current working directory and up to 4 of its parents.
    2. This script's directory and up to 4 of its parents (where the
       scientific-writer project root is expected to live).
    3. As a last resort, python-dotenv's default discovery (``load_dotenv()``
       with no path), which searches upward from the calling script's
       location rather than from the working directory.

    Only regular files named ``.env`` are accepted. A *directory* called
    ``.env`` (a common virtualenv name) is skipped instead of ending the search.

    Returns:
        True if a .env file was loaded, False if none was found or
        python-dotenv is not installed.
    """
    try:
        from dotenv import load_dotenv
    except ImportError:
        return False  # python-dotenv not installed

    def _dir_and_ancestors(start: Path, depth: int = 5) -> List[Path]:
        # `start` plus up to depth-1 parents. Path.parents ends at the root,
        # so the root itself is checked when it falls within the limit.
        return [start, *start.parents][:depth]

    search_dirs = (
        _dir_and_ancestors(Path.cwd())
        # Package location: the project root sits somewhere above this script.
        + _dir_and_ancestors(Path(__file__).resolve().parent)
    )
    for directory in search_dirs:
        env_path = directory / ".env"
        # is_file(), not exists(): a virtualenv directory named ".env" must not
        # be mistaken for a config file and stop the search.
        if env_path.is_file():
            load_dotenv(dotenv_path=env_path)
            return True

    # Fallback keeps the original, unbounded upward search from the script.
    return bool(load_dotenv())


_load_env_file()
class ScientificSchematicGenerator:
    """Generate scientific schematics using AI with iterative refinement.

    Each round asks the image model for a diagram, has the same model critique
    the result, folds that critique into the next prompt, and regenerates.
    All traffic goes through OpenRouter's chat-completions endpoint.
    """

    # Scientific diagram best practices prompt template
    SCIENTIFIC_DIAGRAM_GUIDELINES = """
Create a high-quality scientific diagram with these requirements:
VISUAL QUALITY:
- Clean white or light background (no textures or gradients)
- High contrast for readability and printing
- Professional, publication-ready appearance
- Sharp, clear lines and text
- Adequate spacing between elements to prevent crowding
TYPOGRAPHY:
- Clear, readable sans-serif fonts (Arial, Helvetica style)
- Minimum 10pt font size for all labels
- Consistent font sizes throughout
- All text horizontal or clearly readable
- No overlapping text
SCIENTIFIC STANDARDS:
- Accurate representation of concepts
- Clear labels for all components
- Include scale bars, legends, or axes where appropriate
- Use standard scientific notation and symbols
- Include units where applicable
ACCESSIBILITY:
- Colorblind-friendly color palette (use Okabe-Ito colors if using color)
- High contrast between elements
- Redundant encoding (shapes + colors, not just colors)
- Works well in grayscale
LAYOUT:
- Logical flow (left-to-right or top-to-bottom)
- Clear visual hierarchy
- Balanced composition
- Appropriate use of whitespace
- No clutter or unnecessary decorative elements
"""

    # Seconds before a single OpenRouter HTTP request is abandoned.
    REQUEST_TIMEOUT = 120

    # Score reported when a review yields no parseable "N/10" score, or when the
    # review call itself fails, so review problems never abort a run.
    DEFAULT_REVIEW_SCORE = 8.0

    def __init__(self, api_key: Optional[str] = None, verbose: bool = False):
        """
        Initialize the generator.

        Args:
            api_key: OpenRouter API key (falls back to the OPENROUTER_API_KEY env var)
            verbose: Print detailed progress information

        Raises:
            ValueError: If no API key is supplied or found in the environment.
        """
        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
        if not self.api_key:
            raise ValueError("OPENROUTER_API_KEY environment variable not set or api_key not provided")
        self.verbose = verbose
        self.base_url = "https://openrouter.ai/api/v1"
        self.image_model = "google/gemini-3-pro-image-preview"
        # Vision-capable review model. NOTE: currently unused; review_image()
        # sends reviews to self.image_model instead.
        self.review_model = "google/gemini-pro-vision"

    def _log(self, message: str):
        """Print a timestamped message when verbose mode is enabled."""
        if self.verbose:
            print(f"[{time.strftime('%H:%M:%S')}] {message}")

    def _make_request(self, model: str, messages: List[Dict[str, Any]],
                      modalities: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Send a chat-completions request to OpenRouter and return the parsed body.

        Args:
            model: Model identifier
            messages: List of message dictionaries
            modalities: Optional list of output modalities (e.g., ["image", "text"])

        Returns:
            The decoded JSON response. A 200 response whose body is not JSON is
            returned as {"raw_text": <first 500 characters>}.

        Raises:
            RuntimeError: On timeout, transport failure, or any non-200 status.
                For non-200 statuses the message includes the API's error detail.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/scientific-writer",
            "X-Title": "Scientific Schematic Generator"
        }
        payload: Dict[str, Any] = {
            "model": model,
            "messages": messages
        }
        if modalities:
            payload["modalities"] = modalities

        self._log(f"Making request to {model}...")
        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT
            )
        except requests.exceptions.Timeout as e:
            raise RuntimeError(f"API request timed out after {self.REQUEST_TIMEOUT} seconds") from e
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"API request failed: {str(e)}") from e

        # Parse the body even on error statuses so the API's explanation is kept.
        try:
            response_json = response.json()
        except ValueError:
            # json.JSONDecodeError and requests' own JSONDecodeError (whose exact
            # class differs across requests versions) are both ValueErrors.
            response_json = {"raw_text": response.text[:500]}

        if response.status_code != 200:
            # Error bodies are normally dicts with an "error" key; anything else
            # is reported whole rather than crashing on .get().
            if isinstance(response_json, dict):
                error_detail = response_json.get("error", response_json)
            else:
                error_detail = response_json
            self._log(f"HTTP {response.status_code}: {error_detail}")
            raise RuntimeError(f"API request failed (HTTP {response.status_code}): {error_detail}")

        return response_json

    @staticmethod
    def _b64decode_clean(data: str) -> bytes:
        """Base64-decode ``data`` after removing line breaks/spaces some providers insert."""
        return base64.b64decode(data.replace('\n', '').replace('\r', '').replace(' ', ''))

    @classmethod
    def _decode_image_block(cls, block: Dict[str, Any]) -> Optional[bytes]:
        """Decode an ``{"type": "image_url", ...}`` block that holds a base64 data URL.

        Accepts both ``{"image_url": {"url": "data:..."}}`` and the bare-string
        ``{"image_url": "data:..."}`` form. Returns None when the block does not
        carry an inline ``data:image`` URL.
        """
        url = block.get("image_url", {})
        if isinstance(url, dict):
            url = url.get("url", "")
        if not isinstance(url, str) or not url.startswith("data:image") or "," not in url:
            return None
        # Everything after the first comma is the base64 payload.
        return cls._b64decode_clean(url.split(",", 1)[1])

    def _extract_image_from_response(self, response: Dict[str, Any]) -> Optional[bytes]:
        """
        Extract the first base64-encoded image from an API response.

        Nano Banana Pro returns images in the message's 'images' field rather
        than in 'content'. The 'content' field (a string or a list of blocks) is
        checked as a fallback.

        Args:
            response: API response dictionary

        Returns:
            Image bytes, or None if no inline image was found or decoding failed.
        """
        try:
            choices = response.get("choices") or []
            if not choices:
                self._log("No choices in response")
                return None
            message = choices[0].get("message") or {}

            # IMPORTANT: Nano Banana Pro returns images in the 'images' field
            images = message.get("images") or []
            if images:
                self._log(f"Found {len(images)} image(s) in 'images' field")
                first_image = images[0]
                if isinstance(first_image, dict) and first_image.get("type") == "image_url":
                    image_bytes = self._decode_image_block(first_image)
                    if image_bytes is not None:
                        self._log(f"Extracted image from 'images' field ({len(image_bytes)} bytes)")
                        return image_bytes

            # Fallback: check content field (for other models or future changes)
            content = message.get("content", "")
            if self.verbose:
                self._log(f"Content type: {type(content)}, length: {len(str(content))}")

            # String content: look for an embedded data URL anywhere in the text.
            if isinstance(content, str) and "data:image" in content:
                match = re.search(r'data:image/[^;]+;base64,([A-Za-z0-9+/=\n\r]+)', content)
                if match:
                    self._log(f"Found image in content field (length: {len(match.group(1))})")
                    return self._b64decode_clean(match.group(1))

            # List content: the first image_url block with an inline data URL wins.
            if isinstance(content, list):
                for i, block in enumerate(content):
                    if isinstance(block, dict) and block.get("type") == "image_url":
                        image_bytes = self._decode_image_block(block)
                        if image_bytes is not None:
                            self._log(f"Found image in content block {i}")
                            return image_bytes

            self._log("No image data found in response")
            return None
        except Exception as e:
            # Best effort: a malformed payload means "no image", not a crash.
            self._log(f"Error extracting image: {str(e)}")
            if self.verbose:
                traceback.print_exc()
            return None

    def _image_to_base64(self, image_path: str) -> str:
        """
        Convert an image file to a base64 data URL.

        Args:
            image_path: Path to image file

        Returns:
            Data URL string; the MIME type comes from the file extension and
            defaults to image/png for unknown extensions.
        """
        image_data = Path(image_path).read_bytes()
        ext = Path(image_path).suffix.lower()
        mime_type = {
            ".png": "image/png",
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".gif": "image/gif",
            ".webp": "image/webp"
        }.get(ext, "image/png")
        base64_data = base64.b64encode(image_data).decode("utf-8")
        return f"data:{mime_type};base64,{base64_data}"

    def _log_response_structure(self, response: Dict[str, Any]):
        """Verbose helper: log the response's shape without dumping base64 payloads."""
        self._log(f"Response keys: {response.keys()}")
        if "error" in response:
            self._log(f"API Error: {response['error']}")
        choices = response.get("choices")
        if not choices:
            return
        msg = choices[0].get("message", {})
        self._log(f"Message keys: {msg.keys()}")
        content = msg.get("content", "")
        if isinstance(content, str):
            preview = content[:200] + "..." if len(content) > 200 else content
            self._log(f"Content preview: {preview}")
        elif isinstance(content, list):
            self._log(f"Content is list with {len(content)} items")
            for i, item in enumerate(content[:3]):
                if isinstance(item, dict):
                    self._log(f" Item {i}: type={item.get('type')}")

    def generate_image(self, prompt: str) -> Optional[bytes]:
        """
        Generate an image using Nano Banana Pro.

        Args:
            prompt: Description of the diagram to generate

        Returns:
            Image bytes, or None if generation failed (errors are printed or
            logged, never raised).
        """
        messages = [
            {
                "role": "user",
                "content": prompt
            }
        ]
        try:
            response = self._make_request(
                model=self.image_model,
                messages=messages,
                modalities=["image", "text"]
            )

            if self.verbose:
                self._log_response_structure(response)

            # An "error" key in the body means the API rejected the request.
            if "error" in response:
                error_msg = response["error"]
                if isinstance(error_msg, dict):
                    error_msg = error_msg.get("message", str(error_msg))
                print(f"✗ API Error: {error_msg}")
                return None

            image_data = self._extract_image_from_response(response)
            if image_data:
                self._log(f"✓ Generated image ({len(image_data)} bytes)")
            else:
                self._log("✗ No image data in response")
                # Additional debug info when image extraction fails; guard
                # against an empty choices list instead of raising IndexError.
                choices = response.get("choices") or []
                if self.verbose and choices:
                    msg = choices[0].get("message", {})
                    self._log(f"Full message structure: {json.dumps({k: type(v).__name__ for k, v in msg.items()})}")
            return image_data
        except Exception as e:
            self._log(f"✗ Generation failed: {str(e)}")
            if self.verbose:
                traceback.print_exc()
            return None

    @staticmethod
    def _parse_score(text: str) -> Optional[float]:
        """Extract an "N/10"-style score from reviewer text, or None if absent.

        Accepts Markdown emphasis around the label or number (``**Score:** 7/10``)
        and the "N out of 10" phrasing. The result is clamped to [0, 10].
        """
        match = re.search(
            r'(?:score|rating|quality)[\s:*_]*(\d+(?:\.\d+)?)[\s*_]*(?:/|out\s+of)\s*10',
            text,
            re.IGNORECASE,
        )
        if not match:
            return None
        return min(max(float(match.group(1)), 0.0), 10.0)

    @staticmethod
    def _build_review_prompt(original_prompt: str, iteration: int,
                             total_iterations: int = 3) -> str:
        """Build the critique instructions sent alongside the image for review."""
        return f"""You are reviewing a scientific diagram you just generated.
ORIGINAL REQUEST: {original_prompt}
ITERATION: {iteration}/{total_iterations}
Evaluate this diagram on:
1. Scientific accuracy
2. Clarity and readability
3. Label quality
4. Layout and composition
5. Professional appearance
Provide a score (0-10) and specific suggestions for improvement."""

    def review_image(self, image_path: str, original_prompt: str,
                     iteration: int, total_iterations: int = 3) -> Tuple[str, float]:
        """
        Review a generated image using AI quality analysis.

        The image model itself performs the review, since it accepts image input.

        Args:
            image_path: Path to the generated image
            original_prompt: Original user prompt
            iteration: Current iteration number (1-based)
            total_iterations: Total number of iterations in this run

        Returns:
            Tuple of (critique text, quality score 0-10). The score is
            DEFAULT_REVIEW_SCORE when the reply has no parseable score or
            the review request fails.

        Raises:
            OSError: If ``image_path`` cannot be read. API failures do not raise.
        """
        image_data_url = self._image_to_base64(image_path)
        review_prompt = self._build_review_prompt(original_prompt, iteration, total_iterations)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": review_prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_data_url
                        }
                    }
                ]
            }
        ]
        try:
            # Use the same Nano Banana Pro model for review (it has vision)
            response = self._make_request(
                model=self.image_model,
                messages=messages
            )

            choices = response.get("choices", [])
            if not choices:
                return "Image generated successfully", self.DEFAULT_REVIEW_SCORE

            message = choices[0].get("message", {})
            # The API may send "content": null; normalize to "" so the parsing
            # below works instead of failing into the "review skipped" path.
            content = message.get("content") or ""
            # Nano Banana Pro may put its analysis in 'reasoning' instead.
            reasoning = message.get("reasoning") or ""
            if reasoning and not content:
                content = reasoning
            if isinstance(content, list):
                # Extract text from content blocks
                content = "\n".join(
                    block.get("text", "")
                    for block in content
                    if isinstance(block, dict) and block.get("type") == "text"
                )
            if not isinstance(content, str):
                content = str(content)

            parsed = self._parse_score(content)
            score = self.DEFAULT_REVIEW_SCORE if parsed is None else parsed
            self._log(f"✓ Review complete (Score: {score}/10)")
            return (content if content else "Image generated successfully"), score
        except Exception as e:
            self._log(f"Review skipped: {str(e)}")
            # Don't fail the whole process if review fails
            return "Image generated successfully (review skipped)", self.DEFAULT_REVIEW_SCORE

    def improve_prompt(self, original_prompt: str, critique: str,
                       iteration: int) -> str:
        """
        Build the next generation prompt from the guidelines, request, and critique.

        Args:
            original_prompt: Original user prompt
            critique: Review critique from previous iteration
            iteration: Iteration number the new prompt is for

        Returns:
            Improved prompt for next generation
        """
        improved_prompt = f"""{self.SCIENTIFIC_DIAGRAM_GUIDELINES}
USER REQUEST: {original_prompt}
ITERATION {iteration}: Based on previous feedback, address these specific improvements:
{critique}
Generate an improved version that addresses all the critique points while maintaining scientific accuracy and professional quality."""
        return improved_prompt

    def generate_iterative(self, user_prompt: str, output_path: str,
                           iterations: int = 3) -> Dict[str, Any]:
        """
        Generate a scientific schematic with iterative refinement.

        Every successful iteration is saved as ``<stem>_v<i><ext>`` next to
        ``output_path``. The most recent successful iteration is copied to
        ``output_path``, even if later iterations failed to generate. A JSON
        review log ``<stem>_review_log.json`` is always written.

        Args:
            user_prompt: User's description of desired diagram
            output_path: Path to save final image
            iterations: Number of refinement iterations (default: 3)

        Returns:
            Dictionary with keys user_prompt, iterations (per-round records),
            final_image, final_iteration, final_score and success (True when
            at least one iteration produced an image).
        """
        output_path = Path(output_path)
        output_dir = output_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)
        base_name = output_path.stem
        extension = output_path.suffix or ".png"

        results: Dict[str, Any] = {
            "user_prompt": user_prompt,
            "iterations": [],
            "final_image": None,
            "final_iteration": None,
            "final_score": 0.0,
            "success": False
        }
        # Most recent iteration that produced an image. Tracking it means a
        # transient failure on the last round no longer discards earlier work.
        latest_success: Optional[Dict[str, Any]] = None

        current_prompt = f"""{self.SCIENTIFIC_DIAGRAM_GUIDELINES}
USER REQUEST: {user_prompt}
Generate a publication-quality scientific diagram that meets all the guidelines above."""

        print(f"\n{'='*60}")
        print("Generating Scientific Schematic")
        print(f"{'='*60}")
        print(f"Description: {user_prompt}")
        print(f"Iterations: {iterations}")
        print(f"Output: {output_path}")
        print(f"{'='*60}\n")

        for i in range(1, iterations + 1):
            print(f"\n[Iteration {i}/{iterations}]")
            print("-" * 40)

            print("Generating image...")
            image_data = self.generate_image(current_prompt)
            if not image_data:
                print("✗ Generation failed")
                results["iterations"].append({
                    "iteration": i,
                    "success": False,
                    "error": "Image generation failed"
                })
                # Retry the next round with the same prompt.
                continue

            iter_path = output_dir / f"{base_name}_v{i}{extension}"
            iter_path.write_bytes(image_data)
            print(f"✓ Saved: {iter_path}")

            print("Reviewing image...")
            critique, score = self.review_image(str(iter_path), user_prompt, i, iterations)
            print(f"✓ Score: {score}/10")

            iteration_result = {
                "iteration": i,
                "image_path": str(iter_path),
                "prompt": current_prompt,
                "critique": critique,
                "score": score,
                "success": True
            }
            results["iterations"].append(iteration_result)
            latest_success = iteration_result

            # No need to refine the prompt after the final round.
            if i < iterations:
                print("Improving prompt based on feedback...")
                current_prompt = self.improve_prompt(user_prompt, critique, i + 1)

        if latest_success is not None:
            results["final_image"] = latest_success["image_path"]
            results["final_iteration"] = latest_success["iteration"]
            results["final_score"] = latest_success["score"]
            results["success"] = True
            if latest_success["iteration"] != iterations:
                print(f"⚠ Final iteration failed; using latest successful version "
                      f"(v{latest_success['iteration']})")
            # Copy the chosen version to the requested output path.
            final_iter_path = Path(results["final_image"])
            if final_iter_path != output_path:
                shutil.copy(final_iter_path, output_path)
            print(f"\n✓ Final image: {output_path}")

        log_path = output_dir / f"{base_name}_review_log.json"
        log_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
        print(f"✓ Review log: {log_path}")

        print(f"\n{'='*60}")
        print("Generation Complete!")
        print(f"Final Score: {results['final_score']}/10")
        print(f"{'='*60}\n")
        return results
def main():
    """Command-line entry point for iterative schematic generation.

    Exit status is 0 when an image was produced. It is 1 for a missing API key,
    an --iterations value outside 1..10, a failed run, or any exception raised
    by the generator. argparse itself exits with 2 on malformed arguments.
    """
    cli = argparse.ArgumentParser(
        description="Generate scientific schematics using AI with iterative refinement",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Generate a flowchart
python generate_schematic_ai.py "CONSORT participant flow diagram" -o flowchart.png
# Generate neural network architecture
python generate_schematic_ai.py "Transformer encoder-decoder architecture" -o transformer.png
# Generate with custom iterations
python generate_schematic_ai.py "Biological signaling pathway" -o pathway.png --iterations 5
# Verbose output
python generate_schematic_ai.py "Circuit diagram" -o circuit.png -v
Environment:
OPENROUTER_API_KEY OpenRouter API key (required)
"""
    )
    cli.add_argument("prompt", help="Description of the diagram to generate")
    cli.add_argument("-o", "--output", required=True, help="Output image path (e.g., diagram.png)")
    cli.add_argument("--iterations", type=int, default=3, help="Number of refinement iterations (default: 3)")
    cli.add_argument("--api-key", help="OpenRouter API key (or set OPENROUTER_API_KEY)")
    cli.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
    opts = cli.parse_args()

    # An explicit --api-key takes precedence over the environment variable.
    key = opts.api_key or os.getenv("OPENROUTER_API_KEY")
    if not key:
        print("\n".join([
            "Error: OPENROUTER_API_KEY environment variable not set",
            "\nSet it with:",
            " export OPENROUTER_API_KEY='your_api_key'",
            "\nOr provide via --api-key flag",
        ]))
        sys.exit(1)

    if not 1 <= opts.iterations <= 10:
        print("Error: Iterations must be between 1 and 10")
        sys.exit(1)

    try:
        generator = ScientificSchematicGenerator(api_key=key, verbose=opts.verbose)
        outcome = generator.generate_iterative(
            user_prompt=opts.prompt,
            output_path=opts.output,
            iterations=opts.iterations,
        )
        if not outcome["success"]:
            print("\n✗ Generation failed. Check review log for details.")
            sys.exit(1)
        print(f"\n✓ Success! Image saved to: {opts.output}")
        sys.exit(0)
    except Exception as exc:
        # SystemExit is not an Exception subclass, so the exits above pass through.
        print(f"\n✗ Error: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()