# 190 lines · 6.6 KiB · Python
import logging
|
|
from datetime import datetime, timezone
|
|
import uuid
|
|
from fastapi import APIRouter, Request, Depends, Form
|
|
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
|
|
from fastapi.templating import Jinja2Templates
|
|
from sqlalchemy import select, desc
|
|
|
|
from app.models.base import async_session
|
|
from app.models.user import User
|
|
from app.models.chat import ChatMessage
|
|
from app.models.workout import Workout
|
|
from app.models.checkin import Checkin
|
|
from app.models.measurement import Measurement, MeasurementType
|
|
from app.auth import get_current_user
|
|
from app.services.opencode_proxy import query_opencode, OpenCodeUnavailableError
|
|
|
|
# Jinja2 environment for the server-rendered chat page. The directory is a
# relative path, so it is resolved against the process working directory.
templates = Jinja2Templates(directory="app/templates")

# Router that collects every chat endpoint defined in this module.
router = APIRouter()


# Logger used for chat failures. The explicit name "chat" (rather than
# __name__) is kept as-is, since logging configuration may refer to it.
logger = logging.getLogger("chat")

# Canned opening message for a brand-new user. chat_page hands it to the
# chat.html template as "onboarding_prompt" together with the ``first`` flag;
# presumably the template submits it on the first visit (confirm in chat.html).
ONBOARDING_PROMPT = "".join((
    "I am a new user and this is the first interaction with this application. ",
    "Please review my goals, equipment, medical history, and stats (all in your context). ",
    "Create a high-level training plan that you can use to make weekly plans and daily workouts. ",
    "Ask for more clarity whenever you need it. ",
    "When the workout planning is ready for me to begin, prompt me to switch to the Workouts or Dashboard tab.",
))


@router.get("/api/chat/messages")
async def get_chat_messages(request: Request, user: User = Depends(get_current_user)):
    """Return the latest messages of the caller's current chat session as JSON.

    The session is identified by the ``chat_session_id`` cookie. Without that
    cookie an empty JSON list is returned. Otherwise the 15 newest messages
    that belong to *user* in that session are fetched and returned
    oldest-first, each as ``{"role", "content", "created_at"}``.
    """
    chat_session_id = request.cookies.get("chat_session_id")
    if not chat_session_id:
        return JSONResponse([])

    newest_first = (
        select(ChatMessage)
        .where(
            ChatMessage.user_id == user.id,
            ChatMessage.session_id == chat_session_id,
        )
        .order_by(desc(ChatMessage.created_at))
        .limit(15)
    )
    async with async_session() as db:
        latest = (await db.execute(newest_first)).scalars().all()

    # The query selects the newest rows. Emit them in chronological order.
    # created_at is passed through as stored; this module writes ISO-8601
    # strings for it.
    payload = [
        {"role": msg.role, "content": msg.content, "created_at": msg.created_at}
        for msg in reversed(latest)
    ]
    return JSONResponse(payload)


@router.get("/chat", response_class=HTMLResponse)
async def chat_page(request: Request, user: User = Depends(get_current_user), first: int = 0):
    """Render the chat page with the full history of the current chat session.

    Reuses the ``chat_session_id`` cookie when it is present and non-empty.
    Otherwise a fresh UUID4 string becomes the session id. All of *user*'s
    messages in that session are loaded oldest-first and rendered into
    ``chat.html``, together with the ``first`` query flag and ONBOARDING_PROMPT.
    The cookie is re-issued on every render, which pushes its 30-day expiry
    forward.
    """
    session_id = request.cookies.get("chat_session_id") or str(uuid.uuid4())

    history_query = (
        select(ChatMessage)
        .where(
            ChatMessage.user_id == user.id,
            ChatMessage.session_id == session_id,
        )
        .order_by(ChatMessage.created_at)
    )
    async with async_session() as db:
        history = (await db.execute(history_query)).scalars().all()

    context = {
        "user": user,
        "messages": history,
        "session_id": session_id,
        "first": first,
        "onboarding_prompt": ONBOARDING_PROMPT,
    }
    response = templates.TemplateResponse(request, "chat.html", context)
    response.set_cookie(
        key="chat_session_id",
        value=session_id,
        httponly=True,
        max_age=86400 * 30,  # 30 days, in seconds
    )
    return response


# Check-in attributes summarised in the model context, in output order, each
# paired with its label template. Attributes whose value is falsy (None, 0, "")
# are left out of the summary line.
_CHECKIN_FIELDS = (
    ("feeling", "feeling={}"),
    ("weight_lb", "weight={}lb"),
    ("calories", "cal(yesterday)={}"),
    ("steps", "steps(yesterday)={}"),
    ("sleep_hours", "sleep={}h"),
)


def _sse_event(data: str) -> str:
    """Frame *data* as a single Server-Sent Events ``data:`` event."""
    # NOTE(review): in the SSE format a newline inside *data* ends the field,
    # and a blank line ends the event. Multi-line chunks are therefore not
    # framed losslessly. Emitting one "data:" line per text line would fix
    # this, but the hand-written client parser in chat.html must be checked
    # before changing the wire format.
    return f"data: {data}\n\n"


def _format_workout(workout: Workout) -> str:
    """Render one workout as a ``date — name (status)`` context line."""
    return f" {workout.date} — {workout.name} ({workout.status})"


def _format_checkin(checkin: Checkin) -> str:
    """Render one check-in as a ``date — k=v | k=v`` context line, skipping empty fields."""
    parts = []
    for attr, template in _CHECKIN_FIELDS:
        value = getattr(checkin, attr)
        if value:
            parts.append(template.format(value))
    return f" {checkin.date} — {' | '.join(parts)}"


def _build_user_context(user: User, recent_workouts, recent_checkins, latest_weight) -> str:
    """Assemble the plain-text profile/activity summary sent with a chat message.

    Args:
        user: The requesting user. Empty profile fields fall back to placeholders.
        recent_workouts: Workouts to list. When empty, the section is omitted.
        recent_checkins: Check-ins to list. When empty, the section is omitted.
        latest_weight: The newest "Weight" measurement, or None if there is none.

    Returns:
        The context string passed to query_opencode.
    """
    weight_str = f"{latest_weight.value} lb" if latest_weight is not None else "Not recorded"
    context = (
        f"Username: {user.username}. "
        f"Weight: {weight_str}. "
        f"Goals: {user.goals or 'Not specified'}. "
        f"Equipment: {user.equipment or 'Not specified'}. "
        f"Medical: {user.medical_notes or 'None'}. "
        f"Vital stats: {user.vital_stats or 'Not specified'}. "
    )
    if recent_workouts:
        context += "Recent workouts:\n" + "\n".join(_format_workout(w) for w in recent_workouts) + ". "
    if recent_checkins:
        context += "Recent check-ins:\n" + "\n".join(_format_checkin(c) for c in recent_checkins) + ". "
    return context


async def _load_user_context(user: User) -> str:
    """Query *user*'s recent activity and return it formatted by _build_user_context.

    Loads the 5 most recent workouts and the 5 most recent check-ins (newest
    first by ``date``). Also loads the newest measurement whose type is named
    "Weight".
    """
    async with async_session() as session:
        workouts = (
            await session.execute(
                select(Workout)
                .where(Workout.user_id == user.id)
                .order_by(desc(Workout.date))
                .limit(5)
            )
        ).scalars().all()

        checkins = (
            await session.execute(
                select(Checkin)
                .where(Checkin.user_id == user.id)
                .order_by(desc(Checkin.date))
                .limit(5)
            )
        ).scalars().all()

        latest_weight = (
            await session.execute(
                select(Measurement)
                .join(MeasurementType, Measurement.measurement_type_id == MeasurementType.id)
                .where(
                    Measurement.user_id == user.id,
                    MeasurementType.name == "Weight",
                )
                .order_by(desc(Measurement.date))
                .limit(1)
            )
        ).scalar_one_or_none()

        return _build_user_context(user, workouts, checkins, latest_weight)


async def _save_chat_message(user_id, session_id: str, role: str, content: str) -> None:
    """Insert and commit one ChatMessage, stamped with the current UTC time (ISO-8601)."""
    async with async_session() as session:
        session.add(
            ChatMessage(
                user_id=user_id,
                session_id=session_id,
                role=role,
                content=content,
                created_at=datetime.now(timezone.utc).isoformat(),
            )
        )
        await session.commit()


@router.post("/chat")
async def chat_send(
    request: Request,
    user: User = Depends(get_current_user),
    message: str = Form(),
):
    """Send *message* to the assistant and stream its reply as Server-Sent Events.

    The chat session id comes from the ``chat_session_id`` cookie. A new UUID4
    is generated when the cookie is absent, and the cookie is (re)set on the
    response with a 30-day lifetime. The user's profile and recent activity
    are passed along to the assistant as context.

    Stream behaviour:
    - The user's message is stored first.
    - Each reply chunk is forwarded as a ``data:`` event.
    - When the reply finishes, it is stored as the assistant message.
    - If query_opencode raises OpenCodeUnavailableError, the failure is logged,
      a ``data: [error] ...`` event is sent, and no assistant message is stored.
    """
    session_id = request.cookies.get("chat_session_id") or str(uuid.uuid4())
    user_context = await _load_user_context(user)

    async def stream():
        await _save_chat_message(user.id, session_id, "user", message)

        # Collect chunks and join once, instead of repeated string concatenation.
        chunks = []
        try:
            async for chunk in query_opencode(message, session_id, user_context):
                chunks.append(chunk)
                yield _sse_event(chunk)
        except OpenCodeUnavailableError as e:
            logger.error("Chat failed for user %s: %s", user.username, e)
            yield _sse_event(f"[error] {e}")
            # Don't save an error assistant message.
            return

        await _save_chat_message(user.id, session_id, "assistant", "".join(chunks))

    resp = StreamingResponse(stream(), media_type="text/event-stream")
    resp.set_cookie(key="chat_session_id", value=session_id, httponly=True, max_age=86400 * 30)
    return resp