100 lines
4.1 KiB
Python
100 lines
4.1 KiB
Python
"""Ollama へのプロンプト送信と JSON 抽出。"""
|
||
import json
|
||
import re
|
||
import httpx
|
||
from config import OLLAMA_URL, OLLAMA_MODEL
|
||
|
||
_SYSTEM_PROMPT = """\
|
||
あなたはWindmillフロー生成AIです。
|
||
以下のルールを必ず守ってください:
|
||
- JSONのみ出力すること
|
||
- Markdownのコードブロック(```)は使わない
|
||
- 説明文・コメントは一切出力しない
|
||
- フィールド順は必ず summary → value の順にすること
|
||
- 出力するJSONは必ず以下のスキーマに従うこと:
|
||
|
||
{
|
||
"summary": "<タスクを一言で表す英語の説明>",
|
||
"value": {
|
||
"modules": [
|
||
{
|
||
"id": "a",
|
||
"value": {
|
||
"type": "rawscript",
|
||
"language": "python3",
|
||
"content": "<タスクを実行するPython3コード>",
|
||
"input_transforms": {}
|
||
}
|
||
}
|
||
]
|
||
}
|
||
}
|
||
|
||
【必須ルール】
|
||
- content のコードは必ず def main(): で始めること(Windmillのエントリーポイント)
|
||
- main() がない場合は AttributeError になるため絶対に省略しないこと
|
||
- content の内容はユーザーのタスク説明に従って書くこと(テンプレートをそのままコピーしないこと)
|
||
- content 内の改行は \\n でエスケープすること(リテラル改行を入れると JSON パースエラーになる)
|
||
- modules.id は a, b, c... の連番。追加フィールド禁止。
|
||
|
||
【出力例1】タスク: 「おはよう」と表示する
|
||
{"summary":"Print greeting","value":{"modules":[{"id":"a","value":{"type":"rawscript","language":"python3","content":"def main():\\n print('おはよう')","input_transforms":{}}}]}}
|
||
|
||
【出力例2】タスク: 1から5までの数字を表示する
|
||
{"summary":"Print numbers 1 to 5","value":{"modules":[{"id":"a","value":{"type":"rawscript","language":"python3","content":"def main():\\n for i in range(1, 6):\\n print(i)","input_transforms":{}}}]}}
|
||
|
||
【出力例3】タスク: 現在の日時を表示する
|
||
{"summary":"Display current datetime","value":{"modules":[{"id":"a","value":{"type":"rawscript","language":"python3","content":"def main():\\n from datetime import datetime\\n print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))","input_transforms":{}}}]}}\
|
||
"""
|
||
|
||
|
||
def _chat(messages: list[dict]) -> str:
|
||
resp = httpx.post(
|
||
f"{OLLAMA_URL}/api/chat",
|
||
json={
|
||
"model": OLLAMA_MODEL,
|
||
"messages": messages,
|
||
"stream": False,
|
||
"options": {"temperature": 0.1, "top_p": 0.9},
|
||
},
|
||
timeout=120,
|
||
)
|
||
resp.raise_for_status()
|
||
raw = resp.json()["message"]["content"].strip()
|
||
return _extract_json(raw)
|
||
|
||
|
||
def _extract_json(raw: str) -> str:
|
||
"""LLM がコードブロックで囲んでしまった場合でも JSON 部分を取り出す。"""
|
||
# ```json ... ``` または ``` ... ``` を除去
|
||
match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", raw)
|
||
if match:
|
||
return match.group(1).strip()
|
||
return raw
|
||
|
||
|
||
def generate_flow(task_description: str) -> str:
|
||
"""初回生成:タスク説明からフロー JSON を生成する。"""
|
||
messages = [
|
||
{"role": "system", "content": _SYSTEM_PROMPT},
|
||
{"role": "user", "content": f"以下のフローをJSON形式で生成してください。\n要件: {task_description}"},
|
||
]
|
||
return _chat(messages)
|
||
|
||
|
||
def fix_flow(previous_flow_json: str, error_log: str) -> str:
|
||
"""リトライ生成:前回の JSON + エラーログから修正版を生成する。"""
|
||
messages = [
|
||
{"role": "system", "content": _SYSTEM_PROMPT},
|
||
{"role": "user", "content": (
|
||
"前回のフロー実行でエラーが発生しました。修正したフローをJSON形式で出力してください。\n\n"
|
||
f"--- 前回のフローJSON ---\n{previous_flow_json}\n\n"
|
||
f"--- エラーログ ---\n{error_log}\n\n"
|
||
"--- 修正指示 ---\n"
|
||
"- 前回と同一のJSONは絶対に出力しないこと\n"
|
||
"- エラーの原因箇所のみ修正すること\n"
|
||
"- スキーマは変えないこと"
|
||
)},
|
||
]
|
||
return _chat(messages)
|