import calendar
from datetime import datetime, timezone

from flask import request, jsonify, make_response, g
from sqlalchemy import Select, func

from src import db
from src.models import StudentAssesmentResultTable, StudentTable, SchoolTable
from src.utils.response import error_response
from src.services.principals import find_principal_by_username
from src.services.schools import find_school_by_id
from src.services.concepts import find_concepts_by_grade_and_stream
from src.services.students import find_students_by_school_grade_section
from src.lib.bcrypt import encrypt_password, check_password
from src.lib.jwt import generate_token


def login_func():
    """Authenticate a principal and set the ``auth_token`` cookie.

    Expects a JSON body with ``username`` and ``password``. If the principal
    has not logged in before (``isLoggedIn`` is false), the supplied password
    is hashed and stored instead of being checked. On later logins it is
    verified with ``check_password``. In both cases ``last_login`` is updated
    and committed.

    Returns:
        A 200 response with the principal's profile and an HTTP-only
        ``auth_token`` cookie. Otherwise an ``error_response``: 400 when
        either credential is missing (or the body is not a JSON object), and
        401 when the username is unknown or the password does not match.
    """
    # silent=True: a missing/malformed JSON body yields None instead of an
    # exception, and the isinstance check rejects non-object JSON (e.g. a list).
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    username = body.get("username")
    password = body.get("password")

    if not username or not password:
        return error_response("Invalid credentials", 400)

    principal = find_principal_by_username(username)
    if not principal:
        return error_response("Invalid credentials", 401)

    if principal.isLoggedIn:
        if not check_password(principal.password, password):
            return error_response("Invalid credentials", 401)
        message = "Principal logged in successfully"
    else:
        # NOTE(review): whoever first submits a known username sets its
        # password here, with no prior verification — confirm this
        # account-claim flow is intended.
        principal.password = encrypt_password(password)
        principal.isLoggedIn = True
        message = "Principal logged in for the first time — password has been set"

    principal.last_login = datetime.now(timezone.utc)
    db.session.commit()

    token = generate_token(principal.id, "principal")
    school = find_school_by_id(principal.school_id)

    response = make_response(jsonify({
        "success": True,
        "message": message,
        "role": "principal",
        "id": principal.id,
        "username": principal.username,
        "school_id": principal.school_id,
        "firstname": principal.firstname,
        "lastname": principal.lastname,
        # The login is already committed above; a dangling school_id must not
        # turn a successful login into an AttributeError / 500.
        "school_name": school.school_name if school else None
    }), 200)

    # Cross-site cookie: browsers only accept SameSite=None together with
    # Secure, and the leading-dot domain shares it across beyondskool.ai
    # subdomains.
    response.set_cookie(
        "auth_token",
        token,
        httponly=True,
        secure=True,
        samesite="None",
        path="/",
        domain=".beyondskool.ai",
        max_age=7 * 24 * 60 * 60
    )

    return response


def _shift_months(moment, months):
    """Return *moment* moved by *months* calendar months (negative = backwards).

    The day of month is clamped to the last valid day of the target month,
    because ``datetime.replace`` raises ``ValueError`` for dates such as
    31 February.
    """
    month_index = moment.year * 12 + (moment.month - 1) + months
    year, month_zero_based = divmod(month_index, 12)
    month = month_zero_based + 1
    last_day = calendar.monthrange(year, month)[1]
    return moment.replace(year=year, month=month, day=min(moment.day, last_day))


def dashboard_func():
    """Return accuracy statistics for one grade/section of the principal's school.

    Reads ``grade`` and ``section`` from the JSON request body. It then
    aggregates the ``concept_detail`` entries of assessment results submitted
    since six months ago into three figures: an overall accuracy, a
    per-concept overall accuracy, and a per-concept accuracy for each of the
    last six calendar months (current month included).

    Returns:
        ``(Response, 200)`` with ``summary``, ``concepts`` (sorted by overall
        accuracy, highest first) and ``time_period``. Otherwise an
        ``error_response``: 400 when grade/section is missing and 404 when
        the school is not found.
    """
    principal = g.current_user
    school_id = principal.school_id

    # silent=True + isinstance: a missing or non-object body becomes {} so the
    # request is rejected with the 400 below instead of crashing.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    grade = body.get("grade")
    section = body.get("section")

    if not grade or not section:
        return error_response("Grade and section required", 400)

    school = find_school_by_id(school_id)
    if not school:
        return error_response("School not found", 404)

    all_concept = find_concepts_by_grade_and_stream(grade, school.stream)
    # NOTE(review): the helper's name suggests it returns student records, yet
    # the value is serialized directly as "total_students" — confirm it is a
    # JSON-serializable count.
    total_students = find_students_by_school_grade_section(school_id, grade, section)

    today = datetime.now(timezone.utc)
    # NOTE(review): the window starts exactly six months back, but the monthly
    # series below covers only the last six calendar months. Results from the
    # oldest partial month count toward the overall figures without showing in
    # any monthly bucket — confirm this is intended.
    six_months_ago = _shift_months(today, -6)

    stmt_results = Select(StudentAssesmentResultTable).where(
        StudentAssesmentResultTable.school_id == school_id,
        StudentAssesmentResultTable.grade == grade,
        StudentAssesmentResultTable.section == section,
        StudentAssesmentResultTable.submitted_at >= six_months_ago
    )
    all_results = db.session.execute(stmt_results).scalars().all()

    # NOTE(review): unlike all_results, this count has no date filter and so
    # spans all time — confirm that is the intended "total_assessments".
    stmt_assessments = Select(func.count(StudentAssesmentResultTable.id)).where(
        StudentAssesmentResultTable.school_id == school_id,
        StudentAssesmentResultTable.grade == grade,
        StudentAssesmentResultTable.section == section,
        StudentAssesmentResultTable.assesment_id.isnot(None)
    )
    total_assessments = db.session.execute(stmt_assessments).scalar()

    # concept name -> "YYYY-MM" -> {"total_questions": int, "correct_answers": int}
    concept_performance = {}
    overall_correct = 0
    overall_total = 0

    for result in all_results:
        submission_month = result.submitted_at.strftime("%Y-%m")
        # A result without concept_detail contributes nothing instead of
        # raising TypeError.
        for entry in result.concept_detail or []:
            concept_name = entry.get("concept")
            total = entry.get("total", 0)
            correct = entry.get("correct", 0)

            overall_total += total
            overall_correct += correct

            bucket = concept_performance.setdefault(concept_name, {}).setdefault(
                submission_month,
                {"total_questions": 0, "correct_answers": 0}
            )
            bucket["total_questions"] += total
            bucket["correct_answers"] += correct

    overall_accuracy = round((overall_correct / overall_total) * 100, 2) if overall_total else 0

    # Oldest first: five months back through the current month. _shift_months
    # clamps the day, so e.g. 31 March no longer crashes on "31 February".
    month_points = [_shift_months(today, -offset) for offset in range(5, -1, -1)]
    month_labels = [point.strftime("%Y-%m") for point in month_points]
    month_names = [point.strftime("%b %Y") for point in month_points]

    concept_data = []
    for concept in all_concept:
        per_month = concept_performance.get(concept.name, {})

        monthly = []
        for m in month_labels:
            data = per_month.get(m, {})
            total = data.get("total_questions", 0)
            correct = data.get("correct_answers", 0)
            accuracy = round((correct / total) * 100, 2) if total else 0
            monthly.append({"month": m, "accuracy": accuracy})

        tq = sum(d["total_questions"] for d in per_month.values())
        cq = sum(d["correct_answers"] for d in per_month.values())
        overall = round((cq / tq) * 100, 2) if tq else 0

        concept_data.append({
            "id": concept.id,
            "name": concept.name,
            "grade": concept.grade,
            "stream": concept.stream,
            "description": concept.description,
            "overall_accuracy": overall,
            "monthly_performance": monthly
        })

    concept_data.sort(key=lambda x: x["overall_accuracy"], reverse=True)

    return jsonify({
        "success": True,
        "message": "Principal dashboard data fetched successfully",
        "summary": {
            "total_students": total_students,
            "total_assessments": total_assessments,
            "overall_accuracy": overall_accuracy
        },
        "concepts": concept_data,
        "time_period": {
            "start_date": six_months_ago.strftime("%Y-%m-%d"),
            "end_date": today.strftime("%Y-%m-%d"),
            "month_labels": month_labels,
            "month_names": month_names
        }
    }), 200


def all_grades_func():
    """List every grade in the principal's school together with its sections.

    Returns:
        ``(Response, 200)`` whose ``data`` maps each grade to the list of
        distinct sections found among the school's students. Returns an
        ``error_response`` with 404 when the query yields no rows.
    """
    school_id = g.current_user.school_id

    grade_section_rows = db.session.execute(
        Select(StudentTable.grade, StudentTable.section)
        .where(StudentTable.school_id == school_id)
        .distinct()
    ).all()

    if not grade_section_rows:
        return error_response("Grade and section are not found", 404)

    sections_by_grade = {}
    for grade_value, section_value in grade_section_rows:
        if grade_value not in sections_by_grade:
            sections_by_grade[grade_value] = []
        sections_by_grade[grade_value].append(section_value)

    payload = {
        "success": True,
        "message": "Grades and sections fetched successfully!",
        "data": sections_by_grade,
    }
    return jsonify(payload), 200


def section_progress_func():
    """Return per-section mastery statistics for one grade of the principal's school.

    Reads ``grade`` from the JSON request body. For every section that has
    students in that grade, the response reports:

    - the student count;
    - the number of assessment results;
    - the average accuracy over attempted concepts;
    - a per-concept correct/wrong percentage breakdown.

    Returns:
        ``(Response, 200)`` with ``section_analysis`` sorted by section.
        Otherwise an ``error_response``: 400 when ``grade`` is missing or no
        sections exist for it, and 404 when the school is not found.
    """
    principal = g.current_user
    school_id = principal.school_id

    # silent=True + isinstance: a missing, malformed or non-object body is
    # treated as "no grade" instead of raising.
    body = request.get_json(silent=True)
    grade = body.get("grade") if isinstance(body, dict) else None
    if not grade:
        return error_response("Grade is required", 400)

    stmt_school = Select(SchoolTable).where(SchoolTable.id == school_id)
    school = db.session.execute(stmt_school).scalar_one_or_none()
    if not school:
        return error_response("School not found", 404)

    # One grouped query yields both the distinct sections and their student
    # counts, replacing a DISTINCT query plus one COUNT query per section.
    stmt_sections = (
        Select(StudentTable.section, func.count(StudentTable.id))
        .where(
            StudentTable.school_id == school_id,
            StudentTable.grade == grade
        )
        .group_by(StudentTable.section)
    )
    students_per_section = {
        section: count
        for section, count in db.session.execute(stmt_sections).all()
    }

    if not students_per_section:
        return error_response("No sections found for this grade", 400)

    all_concepts = find_concepts_by_grade_and_stream(grade, school.stream)

    stmt_results = Select(StudentAssesmentResultTable).where(
        StudentAssesmentResultTable.school_id == school_id,
        StudentAssesmentResultTable.grade == grade
    )
    all_results = db.session.execute(stmt_results).scalars().all()

    # Bucket results by section once instead of rescanning every result for
    # every section.
    results_by_section = {}
    for result in all_results:
        results_by_section.setdefault(result.section, []).append(result)

    section_analysis = []

    for section, total_students in students_per_section.items():
        section_results = results_by_section.get(section, [])

        # concept name -> {"total": questions asked, "correct": answered right}
        concept_summary = {}
        for result in section_results:
            # A result without concept_detail contributes nothing instead of
            # raising TypeError.
            for entry in result.concept_detail or []:
                name = entry.get("concept")
                totals = concept_summary.setdefault(name, {"total": 0, "correct": 0})
                totals["total"] += entry.get("total", 0)
                totals["correct"] += entry.get("correct", 0)

        concept_analysis = []
        accuracy_sum = 0
        attempted = 0

        for concept in all_concepts:
            data = concept_summary.get(concept.name, {"total": 0, "correct": 0})
            is_attempted = bool(data["total"])
            if is_attempted:
                acc = (data["correct"] / data["total"]) * 100
                accuracy_sum += acc
                attempted += 1
            else:
                acc = 0

            concept_analysis.append({
                "concept": concept.name,
                "correct_percent": round(acc, 2),
                # Keyed on whether the concept was attempted, not on accuracy:
                # a concept answered entirely wrong must report 100% wrong.
                "wrong_percent": round(100 - acc, 2) if is_attempted else 0,
                "isAttempted": is_attempted
            })

        avg_mastery = round(accuracy_sum / attempted, 2) if attempted else 0

        section_analysis.append({
            "section": section,
            "total_students": total_students,
            "assessments_submitted": len(section_results),
            "average_mastery": avg_mastery,
            "concept_analysis": concept_analysis
        })

    section_analysis.sort(key=lambda x: x["section"])

    return jsonify({
        "success": True,
        "message": "Section-wise progress fetched successfully",
        "grade": grade,
        "stream": school.stream,
        "section_analysis": section_analysis
    }), 200


