Files

100 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Ollama へのプロンプト送信と JSON 抽出。"""
import json
import re
import httpx
from config import OLLAMA_URL, OLLAMA_MODEL
_SYSTEM_PROMPT = """\
あなたはWindmillフロー生成AIです。
以下のルールを必ず守ってください:
- JSONのみ出力すること
- Markdownのコードブロック```)は使わない
- 説明文・コメントは一切出力しない
- フィールド順は必ず summary → value の順にすること
- 出力するJSONは必ず以下のスキーマに従うこと
{
"summary": "<タスクを一言で表す英語の説明>",
"value": {
"modules": [
{
"id": "a",
"value": {
"type": "rawscript",
"language": "python3",
"content": "<タスクを実行するPython3コード>",
"input_transforms": {}
}
}
]
}
}
【必須ルール】
- content のコードは必ず def main(): で始めることWindmillのエントリーポイント
- main() がない場合は AttributeError になるため絶対に省略しないこと
- content の内容はユーザーのタスク説明に従って書くこと(テンプレートをそのままコピーしないこと)
- content 内の改行は \\n でエスケープすること(リテラル改行を入れると JSON パースエラーになる)
- modules.id は a, b, c... の連番。追加フィールド禁止。
【出力例1】タスク: 「おはよう」と表示する
{"summary":"Print greeting","value":{"modules":[{"id":"a","value":{"type":"rawscript","language":"python3","content":"def main():\\n print('おはよう')","input_transforms":{}}}]}}
【出力例2】タスク: 1から5までの数字を表示する
{"summary":"Print numbers 1 to 5","value":{"modules":[{"id":"a","value":{"type":"rawscript","language":"python3","content":"def main():\\n for i in range(1, 6):\\n print(i)","input_transforms":{}}}]}}
【出力例3】タスク: 現在の日時を表示する
{"summary":"Display current datetime","value":{"modules":[{"id":"a","value":{"type":"rawscript","language":"python3","content":"def main():\\n from datetime import datetime\\n print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))","input_transforms":{}}}]}}\
"""
def _chat(messages: list[dict]) -> str:
resp = httpx.post(
f"{OLLAMA_URL}/api/chat",
json={
"model": OLLAMA_MODEL,
"messages": messages,
"stream": False,
"options": {"temperature": 0.1, "top_p": 0.9},
},
timeout=120,
)
resp.raise_for_status()
raw = resp.json()["message"]["content"].strip()
return _extract_json(raw)
def _extract_json(raw: str) -> str:
"""LLM がコードブロックで囲んでしまった場合でも JSON 部分を取り出す。"""
# ```json ... ``` または ``` ... ``` を除去
match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", raw)
if match:
return match.group(1).strip()
return raw
def generate_flow(task_description: str) -> str:
"""初回生成:タスク説明からフロー JSON を生成する。"""
messages = [
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "user", "content": f"以下のフローをJSON形式で生成してください。\n要件: {task_description}"},
]
return _chat(messages)
def fix_flow(previous_flow_json: str, error_log: str) -> str:
"""リトライ生成:前回の JSON + エラーログから修正版を生成する。"""
messages = [
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "user", "content": (
"前回のフロー実行でエラーが発生しました。修正したフローをJSON形式で出力してください。\n\n"
f"--- 前回のフローJSON ---\n{previous_flow_json}\n\n"
f"--- エラーログ ---\n{error_log}\n\n"
"--- 修正指示 ---\n"
"- 前回と同一のJSONは絶対に出力しないこと\n"
"- エラーの原因箇所のみ修正すること\n"
"- スキーマは変えないこと"
)},
]
return _chat(messages)