from flask import request, jsonify, make_response, g, send_file
from datetime import datetime, timezone
from src.services.teachers import find_teacher_by_username
from src.lib.bcrypt import encrypt_password, check_password
from src import db
from src.lib.jwt import generate_token
from src.utils.response import error_response
import openpyxl
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
from sqlalchemy import Select, func
import random
from src.services.schools import find_school_by_id
from src.utils.random import randomize_correct_options
from io import BytesIO
from src.utils.remedial import get_concept_weak_students, get_overall_weak_students, get_subconcept_weak_students
from src.models import (
    StudentAssesmentResultTable,
    ConceptTable,
    SubConceptTable,
    QuestionTable,
    TeacherAssessmentTable,
    StudentTable
)


def delete_question_func():
    """Remove a single question from one of the current teacher's assessments.

    Expects a JSON body with ``assessment_id`` and ``question_id``. Only
    assessments created by the logged-in teacher are searched.

    Returns:
        200 with the remaining questions, 400 when either ID is missing,
        401 when no teacher is authenticated, 404 when the assessment or the
        question cannot be found.
    """
    teacher = g.current_user
    if not teacher:
        return error_response("No active teacher found", 401)

    # ``or {}`` keeps a JSON ``null`` body from raising AttributeError below.
    body = request.get_json() or {}
    assessment_id = body.get("assessment_id")
    question_id = body.get("question_id")

    if not assessment_id or not question_id:
        return error_response("Assessment ID and Question ID required", 400)

    # Scope the lookup to the caller so a teacher can only edit their own assessments.
    assessment = db.session.execute(
        Select(TeacherAssessmentTable).where(
            TeacherAssessmentTable.id == assessment_id,
            TeacherAssessmentTable.created_by == teacher.id
        )
    ).scalar_one_or_none()

    if not assessment:
        return error_response("Assessment not found", 404)

    current_questions = assessment.questions or []
    updated_questions = [q for q in current_questions if q.get("id") != question_id]

    # Nothing was filtered out -> the question is not part of this assessment.
    if len(updated_questions) == len(current_questions):
        return error_response("Question not found in assessment", 404)

    # Assign a new list rather than mutating in place so the attribute change is tracked.
    assessment.questions = updated_questions
    db.session.commit()

    return jsonify({
        "success": True,
        "message": "Question deleted successfully",
        "total_questions": len(updated_questions),
        "question": updated_questions
    }), 200


def replace_question_func():
    """Replace one question in a teacher's assessment with an unused question.

    Expects a JSON body with ``assessment_id`` and ``question_id``. The
    replacement is drawn at random from the same concept/subconcept (looked up
    for the assessment's grade and the teacher's stream). It excludes every
    question already in the assessment. It is passed through
    ``randomize_correct_options``, like questions in newly generated
    assessments.

    Returns:
        200 with the full updated question list, 400 when an ID is missing,
        401 when no teacher is authenticated, 404 when the assessment,
        question, concept, subconcept or a spare question cannot be found.
    """
    # Needed because the questions list is mutated in place further down.
    from sqlalchemy.orm.attributes import flag_modified

    teacher = g.current_user
    if not teacher:
        return error_response("No active teacher found", 401)

    # ``or {}`` keeps a JSON ``null`` body from raising AttributeError below.
    body = request.get_json() or {}
    assessment_id = body.get("assessment_id")
    question_id = body.get("question_id")

    if not assessment_id or not question_id:
        return error_response("Assessment ID and Question ID required", 400)

    assessment = db.session.execute(
        Select(TeacherAssessmentTable).where(
            TeacherAssessmentTable.id == assessment_id,
            TeacherAssessmentTable.created_by == teacher.id
        )
    ).scalar_one_or_none()

    if not assessment:
        return error_response("Assessment not found", 404)

    current_questions = assessment.questions or []
    question_index = next(
        (idx for idx, q in enumerate(current_questions) if q.get("id") == question_id),
        None,
    )
    if question_index is None:
        return error_response("Question not found in assessment", 404)

    question_to_replace = current_questions[question_index]
    concept_name = question_to_replace.get("concept")
    subconcept_name = question_to_replace.get("subconcept")

    concept_db = db.session.execute(
        Select(ConceptTable).where(
            ConceptTable.name == concept_name,
            ConceptTable.grade == assessment.grade,
            ConceptTable.stream == teacher.stream
        )
    ).scalar_one_or_none()

    if not concept_db:
        return error_response("Concept not found", 404)

    subconcept_db = db.session.execute(
        Select(SubConceptTable).where(
            SubConceptTable.concept_id == concept_db.id,
            SubConceptTable.name == subconcept_name
        )
    ).scalar_one_or_none()

    if not subconcept_db:
        return error_response("Subconcept not found", 404)

    candidate_questions = db.session.execute(
        Select(QuestionTable).where(
            QuestionTable.subconcept_id == subconcept_db.id
        )
    ).scalars().all()

    # Exclude every question already in the assessment, including the one being replaced.
    used_question_ids = {q.get("id") for q in current_questions}
    available_questions = [q for q in candidate_questions if q.id not in used_question_ids]

    if not available_questions:
        return error_response("No more questions available for this subconcept", 404)

    new_question = random.choice(available_questions)

    options_array = []
    correct_answer_text = ""

    if isinstance(new_question.options, dict):
        # Options are emitted in sorted key order; when correct_answer names a
        # key it is resolved to that option's text.
        for key in sorted(new_question.options.keys()):
            option_value = new_question.options[key]
            options_array.append(option_value)
            if key == new_question.correct_answer:
                correct_answer_text = option_value

    # Fallback: correct_answer did not match an option key, so use it as-is.
    if not correct_answer_text:
        correct_answer_text = new_question.correct_answer

    new_question_data = randomize_correct_options([{
        "id": new_question.id,
        "question": new_question.question,
        "options": options_array,
        "answer": correct_answer_text,
        "concept": concept_name,
        "subconcept": subconcept_name,
        "difficulty": new_question.difficulty
    }])[0]

    # In-place list mutation is not picked up by change tracking (presumably a
    # plain, non-Mutable JSON column), so mark the attribute dirty explicitly.
    # commit() flushes on its own; no separate flush() is needed.
    assessment.questions[question_index] = new_question_data
    flag_modified(assessment, "questions")
    db.session.commit()

    return jsonify({
        "success": True,
        "message": "Question replaced successfully",
        "question": assessment.questions,
        "total_questions": len(assessment.questions)
    }), 200


def verify_assessment_func():
    """Mark one of the current teacher's assessments as verified.

    Expects a JSON body with ``assessment_id``. This function only sets
    ``is_verified`` and commits. The response message says the assessment was
    sent to students; presumably student-facing queries filter on this flag
    (verify against callers).

    Returns:
        200 on success, 400 when the ID is missing, 401 when no teacher is
        authenticated, 404 when the assessment is not owned by the teacher.
    """
    teacher = g.current_user
    if not teacher:
        return error_response("No active teacher found", 401)

    # ``or {}`` keeps a JSON ``null`` body from raising AttributeError below.
    body = request.get_json() or {}
    assessment_id = body.get("assessment_id")

    if not assessment_id:
        return error_response("Assessment ID required", 400)

    assessment = db.session.execute(
        Select(TeacherAssessmentTable).where(
            TeacherAssessmentTable.id == assessment_id,
            TeacherAssessmentTable.created_by == teacher.id
        )
    ).scalar_one_or_none()

    if not assessment:
        return error_response("Assessment not found", 404)

    assessment.is_verified = True
    db.session.commit()

    return jsonify({
        "success": True,
        "message": "Assessment verified and sent to students successfully"
    }), 200


def login_func():
    """Authenticate a teacher and set the ``auth_token`` cookie.

    Expects a JSON body with ``username`` and ``password``.

    When ``teacher.isLoggedIn`` is false (first login), the supplied password
    is not checked. It is hashed and stored as the account password, and the
    flag is set. Later logins verify the password with ``check_password``.

    Returns:
        200 with basic profile data and an ``auth_token`` cookie. 400 when a
        credential is missing; 401 for an unknown user or a wrong password.
    """
    body = request.get_json() or {}
    username = body.get("username")
    password = body.get("password")

    if not username or not password:
        return error_response("Invalid credentials", 400)

    teacher = find_teacher_by_username(username)
    if not teacher:
        return error_response("Invalid credentials", 401)

    if teacher.isLoggedIn:
        if not check_password(teacher.password, password):
            return error_response("Invalid credentials", 401)
    else:
        # NOTE(review): the first login accepts any password and stores it.
        # Whoever logs in first with a known username therefore claims the
        # account. Consider requiring an initial/one-time password here.
        teacher.password = encrypt_password(password)
        teacher.isLoggedIn = True

    teacher.last_login = datetime.now(timezone.utc)
    db.session.commit()

    token = generate_token(teacher.id, "teacher")
    school = find_school_by_id(teacher.school_id)

    response = make_response(jsonify({
        "success": True,
        "message": "Teacher logged in successfully",
        "user": {
            "id": teacher.id,
            "username": teacher.username,
            # Guard against a missing school row. Otherwise a successful
            # (already committed) login would fail with a 500.
            "school_name": school.school_name if school else None,
            "grades": teacher.grades,
            "stream": school.stream if school else None
        },
        "role": "teacher"
    }), 200)

    # Cross-site cookie: browsers only accept SameSite=None together with
    # Secure. The domain makes the cookie visible to beyondskool.ai subdomains.
    response.set_cookie(
        "auth_token",
        token,
        httponly=True,
        secure=True,
        samesite="None",
        path="/",
        domain=".beyondskool.ai",
        max_age=7 * 24 * 60 * 60
    )

    return response


def auto_ass_func():
    """Generate and save an assessment aimed at a section's weakest concepts.

    Expects a JSON body with ``grade`` and ``section``. The steps are:

    1. Aggregate the school's results for that grade/section per concept.
    2. Take the five concepts with the lowest accuracy that exist for the
       grade and the teacher's stream.
    3. Sample questions across all their subconcepts. The target is 15-20,
       but at least one question is taken per subconcept, so many subconcepts
       can yield more.

    Returns:
        200 with the saved assessment, 400 when grade/section is missing,
        404 when there is no teacher, no results, or no usable
        concepts/subconcepts/questions.
    """
    teacher = g.current_user
    # Must run before any teacher attribute is read (teacher.stream below).
    if not teacher:
        return error_response("No active teacher found", 404)

    body = request.get_json() or {}
    raw_grade = body.get("grade")
    # str(None) == "None" is truthy and would bypass the required-field check.
    grade = "" if raw_grade is None else str(raw_grade)
    section = body.get("section")
    stream = teacher.stream

    if not grade or not section:
        return error_response("Grade and section required", 400)

    school_id = teacher.school_id

    # Step 1: fetch student results for this school/grade/section
    student_results = db.session.query(StudentAssesmentResultTable).filter(
        StudentAssesmentResultTable.school_id == school_id,
        StudentAssesmentResultTable.grade == grade,
        StudentAssesmentResultTable.section == section
    ).all()

    if not student_results:
        return jsonify({
            "success": False,
            "message": f"No student results found for Grade {grade} Section {section}"
        }), 404

    # Step 2: aggregate question/correct counts per concept
    concept_stats = {}

    def add_stats(concept_name, entry):
        """Fold one detail row's totals into ``concept_stats``; skip empty rows."""
        total = entry.get("total") or 0  # tolerate null as well as a missing key
        if not concept_name or total == 0:
            return
        stats = concept_stats.setdefault(
            concept_name, {"total_questions": 0, "total_correct": 0}
        )
        stats["total_questions"] += total
        stats["total_correct"] += entry.get("correct") or 0

    for result in student_results:
        for concept_obj in result.concept_detail or []:
            add_stats(
                concept_obj.get("concept") or concept_obj.get("conceptname"),
                concept_obj,
            )
        # NOTE(review): subconcept rows are added to the same concept totals.
        # If concept_detail already covers those questions, they are counted
        # twice. The accuracy ratio is unaffected only when both breakdowns sum
        # to the same totals. Confirm this is intended.
        for subconcept_obj in result.subconcept_detail or []:
            add_stats(subconcept_obj.get("concept"), subconcept_obj)

    if not concept_stats:
        return error_response("No concept performance data available", 404)

    # Step 3: rank concepts by accuracy (ascending) and keep the weakest five.
    # concept_stats is non-empty here, so at least one concept is selected.
    concept_performance = sorted(
        (
            {
                "concept": concept_name,
                "accuracy": round((stats["total_correct"] / stats["total_questions"]) * 100, 2),
            }
            for concept_name, stats in concept_stats.items()
        ),
        key=lambda c: c["accuracy"],
    )
    weakest_concepts = [c["concept"] for c in concept_performance[:5]]

    # Step 4: resolve the weak concepts for this grade/stream and list their subconcepts
    all_subconcepts = []
    concept_names = []

    for concept_name in weakest_concepts:
        concept_db = db.session.execute(
            Select(ConceptTable).where(
                ConceptTable.name == concept_name,
                ConceptTable.grade == grade,
                ConceptTable.stream == stream
            )
        ).scalar_one_or_none()

        if not concept_db:
            continue

        concept_names.append(concept_name)

        subconcepts = db.session.execute(
            Select(SubConceptTable).where(
                SubConceptTable.concept_id == concept_db.id
            )
        ).scalars().all()

        all_subconcepts.extend(
            {
                "subconcept_id": sc.id,
                "concept_name": concept_name,
                "subconcept_name": sc.name
            }
            for sc in subconcepts
        )

    if not all_subconcepts:
        return error_response("No subconcepts found for weak concepts", 404)

    # Step 5: sample questions evenly across subconcepts (at least one each)
    total_questions = random.randint(15, 20)
    questions_per_subconcept = max(1, total_questions // len(all_subconcepts))

    transformed_questions = []

    for sc in all_subconcepts:
        questions = db.session.execute(
            Select(QuestionTable).where(
                QuestionTable.subconcept_id == sc["subconcept_id"]
            )
        ).scalars().all()

        picked = random.sample(questions, min(len(questions), questions_per_subconcept))

        for q in picked:
            options = []
            correct_answer = ""

            # Options in sorted key order; a key-style correct_answer is
            # resolved to the option text.
            if isinstance(q.options, dict):
                for k in sorted(q.options.keys()):
                    options.append(q.options[k])
                    if k == q.correct_answer:
                        correct_answer = q.options[k]

            transformed_questions.append({
                "id": q.id,
                "question": q.question,
                "options": options,
                "answer": correct_answer or q.correct_answer,
                "concept": sc["concept_name"],
                "subconcept": sc["subconcept_name"],
                "difficulty": q.difficulty
            })

    if not transformed_questions:
        return error_response("No questions found for weak concepts", 404)

    random.shuffle(transformed_questions)
    transformed_questions = randomize_correct_options(transformed_questions)

    # Step 6: save assessment
    new_assessment = TeacherAssessmentTable(
        school_id=school_id,
        created_by=teacher.id,
        grade=grade,
        section=section,
        concepts=concept_names,
        questions=transformed_questions
    )

    db.session.add(new_assessment)
    db.session.commit()

    return jsonify({
        "success": True,
        "message": "Auto-generated assessment sent successfully",
        "assessment_id": new_assessment.id,
        "total_questions": len(transformed_questions),
        "weakest_concepts": weakest_concepts,
        "concept_performance": concept_performance[:5],
        "question": transformed_questions
    }), 200


def send_ass_func():
    """Create an assessment from teacher-selected concepts and subconcepts.

    Expects a JSON body with ``grade``, ``section`` and ``concepts``. Each
    entry of ``concepts`` is a dict with ``concept_id``, ``concept`` (name)
    and ``subconcept`` (list of subconcept names). A random target of 15-25
    questions is spread evenly over the matched subconcepts, with at least
    one question per subconcept.

    Returns:
        200 with the saved assessment, 400 when a field is missing or no
        subconcept matched, 404 when there is no teacher or no questions.
    """
    teacher = g.current_user
    if not teacher:
        return error_response("No active teacher found", 404)

    # ``or {}`` keeps a JSON ``null`` body from raising AttributeError below.
    body = request.get_json() or {}
    concepts_with_sub = body.get("concepts")
    raw_grade = body.get("grade")
    # str(None) == "None" is truthy and would bypass the required-field check.
    grade = "" if raw_grade is None else str(raw_grade)
    section = body.get("section")

    if not concepts_with_sub or not grade or not section:
        return error_response("Required field is missing", 400)

    subconcept_ids = []
    concept_names = []
    subconcept_mapping = {}  # subconcept id -> (concept name, subconcept name)

    for concept in concepts_with_sub:
        concept_id = concept.get("concept_id")
        concept_name = concept.get("concept")
        concept_names.append(concept_name)

        for sc_name in concept.get("subconcept") or []:
            matches = db.session.execute(
                Select(SubConceptTable).where(
                    SubConceptTable.concept_id == concept_id,
                    SubConceptTable.name == sc_name
                )
            ).scalars().all()

            for sub in matches:
                # A subconcept requested twice would be sampled twice and could
                # put the same question into the assessment more than once.
                if sub.id in subconcept_mapping:
                    continue
                subconcept_ids.append(sub.id)
                subconcept_mapping[sub.id] = (concept_name, sc_name)

    if not subconcept_ids:
        return error_response("No valid subconcepts found", 400)

    total_questions = random.randint(15, 25)
    questions_per_subconcept = max(1, total_questions // len(subconcept_ids))
    transformed_questions = []

    for sc_id in subconcept_ids:
        all_ques = db.session.execute(
            Select(QuestionTable).where(
                QuestionTable.subconcept_id == sc_id
            )
        ).scalars().all()

        if not all_ques:
            continue

        concept_name, subconcept_name = subconcept_mapping[sc_id]
        picked = random.sample(all_ques, min(len(all_ques), questions_per_subconcept))

        for q in picked:
            # Build the options array and resolve the correct answer's text.
            options_array = []
            correct_answer_text = ""

            if isinstance(q.options, dict):
                for key in sorted(q.options.keys()):
                    option_value = q.options[key]
                    options_array.append(option_value)
                    if key == q.correct_answer:
                        correct_answer_text = option_value

            # Fallback: correct_answer did not match an option key, so use it as-is.
            if not correct_answer_text:
                correct_answer_text = q.correct_answer

            transformed_questions.append({
                "id": q.id,
                "question": q.question,
                "options": options_array,
                "answer": correct_answer_text,
                "concept": concept_name,
                "subconcept": subconcept_name,
                "difficulty": q.difficulty
            })

    if not transformed_questions:
        return error_response("No questions found for selected subconcepts", 404)

    # Randomize correct answer positions to avoid patterns
    transformed_questions = randomize_correct_options(transformed_questions)

    new_assessment = TeacherAssessmentTable(
        school_id=teacher.school_id,
        created_by=teacher.id,
        grade=grade,
        section=section,
        concepts=concept_names,
        questions=transformed_questions,
    )

    db.session.add(new_assessment)
    db.session.commit()

    return jsonify({
        "success": True,
        "message": "Assessment sent to students successfully",
        "concept": concept_names,
        "question": transformed_questions,
        "total_questions": len(transformed_questions),
        "total_subconcepts": len(subconcept_ids),
        "questions_per_subconcept": questions_per_subconcept,
        "assessment_id": new_assessment.id,
    }), 200


def overview_func():
    """Return per-student, per-concept mastery for one grade/section.

    Expects a JSON body with ``grade`` and ``section``. The concept order is:
    concepts covered by this teacher's assessments (sorted), then the
    remaining concepts of the grade/stream (sorted). A concept counts as
    attempted when any of the student's results names it.
    ``average_mastery`` averages accuracy over attempted concepts only.

    Returns:
        200 with the overview payload, or 401 when no teacher is authenticated.
    """
    teacher = g.current_user
    if not teacher:
        return error_response("No active teacher found", 401)

    # ``or {}`` keeps a JSON ``null`` body from raising AttributeError below.
    body = request.get_json() or {}
    teacher_id = teacher.id
    school_id = teacher.school_id
    grade = body.get("grade")
    section = body.get("section")
    stream = teacher.stream

    students = db.session.execute(
        Select(StudentTable).where(
            (StudentTable.school_id == school_id) &
            (StudentTable.grade == grade) &
            (StudentTable.section == section)
        )
    ).scalars().all()
    total_students = len(students)

    assessments = db.session.execute(
        Select(TeacherAssessmentTable).where(
            (TeacherAssessmentTable.created_by == teacher_id) &
            (TeacherAssessmentTable.grade == grade) &
            (TeacherAssessmentTable.section == section)
        )
    ).scalars().all()
    assessment_ids = [a.id for a in assessments]

    # Distinct concepts covered by this teacher's assessments. None entries are
    # skipped because they cannot be sorted alongside concept names.
    tested_concepts = {
        c for a in assessments for c in (a.concepts or []) if c is not None
    }

    all_concepts = db.session.execute(
        Select(ConceptTable).where(
            (ConceptTable.grade == str(grade)) &
            (ConceptTable.stream == stream)
        )
    ).scalars().all()
    all_concept_names = sorted(c.name for c in all_concepts)

    tested_concepts_sorted = sorted(tested_concepts)
    remaining_concepts_sorted = [
        c for c in all_concept_names if c not in tested_concepts
    ]
    final_concept_order = tested_concepts_sorted + remaining_concepts_sorted

    # Count submissions to this teacher's assessments in SQL instead of loading the rows.
    total_assessment_submitted = db.session.execute(
        Select(func.count()).select_from(StudentAssesmentResultTable).where(
            (StudentAssesmentResultTable.school_id == school_id) &
            (StudentAssesmentResultTable.grade == grade) &
            (StudentAssesmentResultTable.section == section) &
            (StudentAssesmentResultTable.assesment_id.in_(assessment_ids))
        )
    ).scalar()

    # Mastery uses every result for the section, not just this teacher's assessments.
    submissions = db.session.execute(
        Select(StudentAssesmentResultTable).where(
            (StudentAssesmentResultTable.school_id == school_id) &
            (StudentAssesmentResultTable.grade == grade) &
            (StudentAssesmentResultTable.section == section)
        )
    ).scalars().all()

    # Group once instead of rescanning every submission for every student.
    results_by_student = {}
    for submission in submissions:
        results_by_student.setdefault(submission.student_id, []).append(submission)

    students_analysis = []

    for student in students:
        concept_summary = {}
        attempted_concepts = set()

        for result in results_by_student.get(student.id, []):
            for concept_data in result.concept_detail or []:
                name = concept_data.get("concept")
                if name:
                    attempted_concepts.add(name)

                summary = concept_summary.setdefault(name, {"total": 0, "correct": 0})
                summary["total"] += concept_data.get("total", 0)
                summary["correct"] += concept_data.get("correct", 0)

        concept_analysis = []
        attempted_accuracies = []

        for concept in final_concept_order:
            data = concept_summary.get(concept, {"total": 0, "correct": 0})
            total = data["total"]
            accuracy = (data["correct"] / total) * 100 if total > 0 else 0
            is_attempted = concept in attempted_concepts

            # Only attempted concepts contribute to the average.
            if is_attempted:
                attempted_accuracies.append(accuracy)

            concept_analysis.append({
                "concept": concept,
                "correct_percent": round(accuracy, 2),
                "wrong_percent": round(100 - accuracy, 2),
                "isAttempted": is_attempted
            })

        avg_mastery = (
            round(sum(attempted_accuracies) / len(attempted_accuracies), 2)
            if attempted_accuracies else 0
        )

        students_analysis.append({
            "student_id": student.id,
            "student_name": f"{student.firstname} {student.lastname}",
            "average_mastery": avg_mastery,
            "concept_analysis": concept_analysis
        })

    return jsonify({
        "success": True,
        "message": "Teacher overview fetched successfully",
        "total_students": total_students,
        "total_assessment_submitted": total_assessment_submitted,
        "tested_concepts": tested_concepts_sorted,
        "remaining_concepts": remaining_concepts_sorted,
        "students_analysis": students_analysis,
        "concept_order": final_concept_order
    }), 200


def download_func():
    """Export a section's per-concept progress as an ``.xlsx`` attachment.

    Expects a JSON body with ``grade`` and ``section``. The sheet has one row
    per student in the teacher's school/grade/section and one column per
    concept. Concepts covered by this teacher's assessments come first
    (sorted), then the remaining concepts of the grade/stream (sorted).

    A concept counts as attempted when the student's results contain at least
    one question for it. ``average_mastery`` is averaged over attempted
    concepts only, matching ``overview_func``. Previously unattempted concepts
    were averaged in as 0, so the export disagreed with the on-screen overview.

    Returns:
        The workbook via ``send_file``, or 401 when no teacher is authenticated.
    """
    teacher = g.current_user
    if not teacher:
        return error_response("No active teacher found", 401)

    # ``or {}`` keeps a JSON ``null`` body from raising AttributeError below.
    body = request.get_json() or {}
    teacher_id = teacher.id
    school_id = teacher.school_id
    grade = body.get("grade")
    section = body.get("section")
    stream = teacher.stream

    students = db.session.execute(
        Select(StudentTable).where(
            (StudentTable.school_id == school_id) &
            (StudentTable.grade == grade) &
            (StudentTable.section == section)
        )
    ).scalars().all()

    assessments = db.session.execute(
        Select(TeacherAssessmentTable).where(
            (TeacherAssessmentTable.created_by == teacher_id) &
            (TeacherAssessmentTable.grade == grade) &
            (TeacherAssessmentTable.section == section)
        )
    ).scalars().all()

    # Distinct concepts covered by this teacher's assessments. None entries are
    # skipped because they cannot be sorted alongside concept names.
    tested_concepts = {
        c for a in assessments for c in (a.concepts or []) if c is not None
    }

    all_concepts = db.session.execute(
        Select(ConceptTable).where(
            (ConceptTable.grade == str(grade)) &
            (ConceptTable.stream == stream)
        )
    ).scalars().all()
    all_concept_names = sorted(c.name for c in all_concepts)

    tested_concepts_sorted = sorted(tested_concepts)
    remaining_concepts_sorted = [
        c for c in all_concept_names if c not in tested_concepts
    ]
    final_concept_order = tested_concepts_sorted + remaining_concepts_sorted

    submissions = db.session.execute(
        Select(StudentAssesmentResultTable).where(
            (StudentAssesmentResultTable.school_id == school_id) &
            (StudentAssesmentResultTable.grade == grade) &
            (StudentAssesmentResultTable.section == section)
        )
    ).scalars().all()

    # Group once instead of rescanning every submission for every student.
    results_by_student = {}
    for submission in submissions:
        results_by_student.setdefault(submission.student_id, []).append(submission)

    students_data = []

    for student in students:
        concept_summary = {}
        for result in results_by_student.get(student.id, []):
            for concept_data in result.concept_detail or []:
                summary = concept_summary.setdefault(
                    concept_data.get("concept"), {"total": 0, "correct": 0}
                )
                summary["total"] += concept_data.get("total", 0)
                summary["correct"] += concept_data.get("correct", 0)

        concept_analysis = {}
        attempted_accuracies = []

        for concept in final_concept_order:
            data = concept_summary.get(concept, {"total": 0, "correct": 0})
            total = data["total"]
            is_attempted = total > 0
            accuracy = (data["correct"] / total) * 100 if is_attempted else 0

            if is_attempted:
                attempted_accuracies.append(accuracy)

            concept_analysis[concept] = {
                "correct_percent": round(accuracy, 2),
                "isAttempted": is_attempted
            }

        avg_mastery = (
            round(sum(attempted_accuracies) / len(attempted_accuracies), 2)
            if attempted_accuracies else 0
        )

        students_data.append({
            "student_name": f"{student.firstname} {student.lastname}",
            "average_mastery": avg_mastery,
            "concept_analysis": concept_analysis
        })

    wb = _build_progress_workbook(grade, section, final_concept_order, students_data)

    output = BytesIO()
    wb.save(output)
    output.seek(0)

    filename = f"Student_Progress_Grade_{grade}_{section}_{datetime.now().strftime('%Y%m%d')}.xlsx"

    return send_file(
        output,
        mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        as_attachment=True,
        download_name=filename
    )


def _build_progress_workbook(grade, section, concept_order, students_data):
    """Render progress rows into a styled single-sheet openpyxl workbook.

    Column A holds student names, followed by one column per entry of
    ``concept_order`` and a final "Average Mastery (%)" column. Attempted
    scores are colour-banded (>= 70 green, >= 40 amber, otherwise red), and
    unattempted cells are grey.
    """
    wb = openpyxl.Workbook()
    ws = wb.active

    # openpyxl rejects sheet titles containing \ / ? * [ ] : and Excel limits
    # titles to 31 characters, so sanitize the user-supplied grade/section.
    raw_title = f"Grade {grade} {section} Progress"
    ws.title = "".join(ch for ch in raw_title if ch not in "\\/?*[]:")[:31]

    header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
    header_font = Font(bold=True, color="FFFFFF", size=12)
    center_alignment = Alignment(horizontal="center", vertical="center")
    thin = Side(style="thin")
    border = Border(left=thin, right=thin, top=thin, bottom=thin)

    # Score fills are created once and reused for every cell.
    strong_fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
    fair_fill = PatternFill(start_color="FFEB9C", end_color="FFEB9C", fill_type="solid")
    weak_fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
    untested_fill = PatternFill(start_color="E7E6E6", end_color="E7E6E6", fill_type="solid")

    def style_header(cell):
        cell.fill = header_fill
        cell.font = header_font
        cell.alignment = center_alignment
        cell.border = border

    style_header(ws.cell(row=1, column=1, value="Student Name"))
    ws.column_dimensions["A"].width = 25

    for col_idx, concept in enumerate(concept_order, start=2):
        style_header(ws.cell(row=1, column=col_idx, value=concept))
        ws.column_dimensions[openpyxl.utils.get_column_letter(col_idx)].width = 15

    avg_col = len(concept_order) + 2
    style_header(ws.cell(row=1, column=avg_col, value="Average Mastery (%)"))
    ws.column_dimensions[openpyxl.utils.get_column_letter(avg_col)].width = 20

    for row_idx, student_data in enumerate(students_data, start=2):
        ws.cell(row=row_idx, column=1, value=student_data["student_name"]).border = border

        for col_idx, concept in enumerate(concept_order, start=2):
            concept_data = student_data["concept_analysis"].get(concept, {})
            score = concept_data.get("correct_percent", 0)

            cell = ws.cell(row=row_idx, column=col_idx, value=f"{score}%")
            cell.alignment = center_alignment
            cell.border = border

            if not concept_data.get("isAttempted", False):
                cell.fill = untested_fill
            elif score >= 70:
                cell.fill = strong_fill
            elif score >= 40:
                cell.fill = fair_fill
            else:
                cell.fill = weak_fill

        avg_cell = ws.cell(
            row=row_idx,
            column=avg_col,
            value=f"{student_data['average_mastery']}%"
        )
        avg_cell.alignment = center_alignment
        avg_cell.border = border
        avg_cell.font = Font(bold=True)

    return wb


def _normalize_label(value):
    """Return *value* as a trimmed, lower-cased string ('' for None/empty)."""
    return str(value or "").strip().lower()


def _resolve_concept(concept_name, grade, stream):
    """Pick the ConceptTable row to use for *concept_name*.

    Rows are matched by name, ignoring case and surrounding whitespace. The
    first matching row that has at least one subconcept for *grade* and
    *stream* is returned; otherwise the first matching row (the caller then
    reports that no subconcepts exist); otherwise None.
    """
    candidates = db.session.execute(
        Select(ConceptTable)
        .where(func.lower(func.trim(ConceptTable.name)) == concept_name.lower())
        .order_by(ConceptTable.id)
    ).scalars().all()

    for candidate in candidates:
        subconcept_count = db.session.execute(
            Select(func.count(SubConceptTable.id)).where(
                SubConceptTable.concept_id == candidate.id,
                SubConceptTable.grade == grade,
                func.trim(SubConceptTable.stream) == stream
            )
        ).scalar()
        if subconcept_count:
            return candidate

    return candidates[0] if candidates else None


def _summarize_subconcepts(results, subconcept_names, selected_concept):
    """Aggregate one student's assessment results into per-subconcept mastery.

    Args:
        results: result rows exposing ``subconcept_detail``, a list of dicts
            with ``concept``, ``subconcept``, ``total`` and ``correct`` keys
            (or None).
        subconcept_names: subconcept names for the concept, in display order.
        selected_concept: concept name echoed back in every entry and used to
            filter detail entries.

    Returns:
        ``(subconcept_analysis, average_mastery)``. ``average_mastery`` is the
        mean correct percentage over attempted subconcepts, rounded to two
        decimals (0 when nothing was attempted).
    """
    target_concept = _normalize_label(selected_concept)

    # Normalised key -> display name. Keeps DB order and drops duplicates so a
    # repeated subconcept name is not counted twice in the average.
    display_names = {}
    for name in subconcept_names:
        display_names.setdefault(_normalize_label(name), name)

    stats = {
        key: {"total_questions": 0, "correct_answers": 0, "attempts": 0}
        for key in display_names
    }

    for result in results:
        for entry in result.subconcept_detail or []:
            if _normalize_label(entry.get("concept")) != target_concept:
                continue
            bucket = stats.get(_normalize_label(entry.get("subconcept")))
            if bucket is None:
                continue
            bucket["total_questions"] += entry.get("total") or 0
            bucket["correct_answers"] += entry.get("correct") or 0
            bucket["attempts"] += 1

    analysis = []
    mastery_scores = []
    for key, display_name in display_names.items():
        bucket = stats[key]
        if bucket["total_questions"] > 0:
            correct_percent = bucket["correct_answers"] / bucket["total_questions"] * 100
            mastery_scores.append(correct_percent)
            analysis.append({
                "subconcept": display_name,
                "concept": selected_concept,
                "correct_percent": round(correct_percent, 2),
                "isAttempted": True,
                "total_questions": bucket["total_questions"],
                "correct_answers": bucket["correct_answers"],
                "attempts": bucket["attempts"]
            })
        else:
            # Never attempted, or only attempted with zero questions.
            analysis.append({
                "subconcept": display_name,
                "concept": selected_concept,
                "correct_percent": 0,
                "isAttempted": False,
                "total_questions": 0,
                "correct_answers": 0,
                "attempts": 0
            })

    average_mastery = sum(mastery_scores) / len(mastery_scores) if mastery_scores else 0
    return analysis, round(average_mastery, 2)


def concept_overview_func():
    """Return per-student subconcept mastery for one concept in a class.

    Expects a JSON body with ``grade``, ``section`` and ``selected_concept``.
    Students are those of the teacher's school in that grade/section. Their
    assessment results for the same grade/section are aggregated against every
    subconcept of the concept for that grade and the teacher's stream.

    Returns 404 when there is no teacher, the concept is unknown, or it has no
    subconcepts for the grade. Returns 400 when a required field is missing.
    """
    teacher = g.current_user
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}

    if not teacher:
        return error_response("No active teacher found", 404)

    stream = (teacher.stream or "").strip()
    grade = body.get("grade")
    section = body.get("section")
    # `or ""` guards against an explicit JSON null, which would break .strip().
    selected_concept = str(body.get("selected_concept") or "").strip()

    if not grade or not section or not selected_concept:
        return error_response(
            "Invalid Payload: grade, section, and selected_concept are required",
            400
        )

    grade_key = str(grade).strip()
    concept = _resolve_concept(selected_concept, grade_key, stream)

    if not concept:
        return error_response(f"Concept '{selected_concept}' not found", 404)

    subconcepts = db.session.execute(
        Select(SubConceptTable).where(
            SubConceptTable.concept_id == concept.id,
            SubConceptTable.grade == grade_key,
            func.trim(SubConceptTable.stream) == stream
        ).order_by(SubConceptTable.id)
    ).scalars().all()

    if not subconcepts:
        return error_response(
            f"No subconcepts available for '{selected_concept}' in Grade {grade}",
            404
        )

    subconcept_names = [sc.name for sc in subconcepts]

    students = db.session.execute(
        Select(StudentTable).where(
            StudentTable.grade == grade,
            StudentTable.section == section,
            StudentTable.school_id == teacher.school_id
        )
    ).scalars().all()

    if not students:
        return jsonify({
            "success": True,
            "selected_concept": selected_concept,
            "students_analysis": [],
            "total_students": 0,
            "total_assessment_submitted": 0
        })

    # One query for the whole class instead of one per student.
    class_results = db.session.execute(
        Select(StudentAssesmentResultTable).where(
            StudentAssesmentResultTable.student_id.in_([s.id for s in students]),
            StudentAssesmentResultTable.grade == grade,
            StudentAssesmentResultTable.section == section
        )
    ).scalars().all()

    results_by_student = {}
    for result in class_results:
        results_by_student.setdefault(result.student_id, []).append(result)

    students_analysis = []
    total_assessment_submitted = 0

    for student in students:
        results = results_by_student.get(student.id, [])
        total_assessment_submitted += len(results)

        student_name = (
            f"{student.firstname} {student.lastname}"
            if student.lastname else student.firstname
        )

        subconcept_analysis, average_mastery = _summarize_subconcepts(
            results, subconcept_names, selected_concept
        )

        students_analysis.append({
            "student_id": student.id,
            "student_name": student_name,
            "subconcept_analysis": subconcept_analysis,
            "average_mastery": average_mastery,
            "total_assessments": len(results)
        })

    return jsonify({
        "success": True,
        "selected_concept": selected_concept,
        "students_analysis": students_analysis,
        "total_students": len(students),
        "total_assessment_submitted": total_assessment_submitted
    })


def concept_func():
    """List the concepts for a grade in the current teacher's stream.

    Expects a JSON body with ``grade``. Returns ``{"success": True, "result":
    [...]}``, one dict per ConceptTable row with id, name, grade, stream and
    description.

    Returns 404 when there is no active teacher and 400 when ``grade`` is
    missing.
    """
    teacher = g.current_user

    if not teacher:
        return error_response("No active teacher found.", 404)

    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    grade = body.get("grade")

    if not grade:
        # A missing field is a client error (400), not a missing resource (404).
        return error_response("Payload required", 400)

    concepts = db.session.execute(
        Select(ConceptTable).where(
            ConceptTable.grade == str(grade),
            ConceptTable.stream == teacher.stream
        )
    ).scalars().all()

    return jsonify({
        "success": True,
        "result": [
            {
                "id": c.id,
                "name": c.name,
                "grade": c.grade,
                "stream": c.stream,
                "description": c.description
            }
            for c in concepts
        ]
    })


def sub_concept_func():
    """Return the subconcept names of a concept for the teacher's stream.

    Expects a JSON body with ``grade``, ``section`` and ``concept`` (exact
    concept name). Responds with the subconcept names ordered by id.

    Returns 400 for a missing teacher, missing fields or an unknown concept.
    Returns 404 when the concept has no subconcepts.
    """
    teacher = g.current_user

    if not teacher:
        return error_response("No active Teacher found", 400)

    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    raw_grade = body.get("grade")
    section = body.get("section")
    concept = body.get("concept")
    stream = teacher.stream

    # Validate the raw value: str(None) == "None" is truthy and would
    # otherwise let a missing grade slip past this check.
    if not raw_grade or not section or not concept:
        return jsonify({
            "success": False,
            "message": "Missing required fields"
        }), 400

    grade = str(raw_grade)

    # .first() instead of scalar_one_or_none(): duplicate concept rows for the
    # same name/grade/stream would otherwise raise MultipleResultsFound (500).
    concept_obj = db.session.execute(
        Select(ConceptTable).where(
            ConceptTable.name == concept,
            ConceptTable.grade == grade,
            ConceptTable.stream == stream
        ).order_by(ConceptTable.id)
    ).scalars().first()

    if not concept_obj:
        return error_response(
            f"Concept '{concept}' not found for grade {grade}",
            400
        )

    subconcepts = db.session.execute(
        Select(SubConceptTable)
        .where(SubConceptTable.concept_id == concept_obj.id)
        .order_by(SubConceptTable.id)
    ).scalars().all()

    if not subconcepts:
        return error_response(
            f"No subconcepts found for concept: {concept}",
            404
        )

    subconcept_list = [subconcept.name for subconcept in subconcepts]

    return jsonify({
        "success": True,
        "message": "Subconcepts fetched successfully",
        "result": subconcept_list,
        "total": len(subconcept_list)
    }), 200


def remedial_func():
    """Return weak-student data for remedial planning.

    JSON body:
        grade, section: required class identifiers.
        is_need_concept: when truthy, also return every concept for the grade
            in the teacher's stream (with subconcept names), together with the
            overall weak-student data.
        concept_name: "overall" (default) or a concept name.
        subconcept_name: "All Subconcepts" (default) or a subconcept name.

    Delegates to get_overall_weak_students / get_concept_weak_students /
    get_subconcept_weak_students depending on the selection.
    """
    teacher = g.current_user
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    grade = body.get("grade")
    section = body.get("section")
    is_need_concept = body.get("is_need_concept", False)
    concept_name = body.get("concept_name", "overall")
    subconcept_name = body.get("subconcept_name", "All Subconcepts")

    if not grade or not section:
        return error_response("Grade or section missing", 400)

    if not teacher:
        return error_response("Unauthorized user! Please login to continue", 400)

    stream = teacher.stream
    school_id = teacher.school_id

    # =========================
    # Get all concepts if requested
    # =========================
    if is_need_concept:
        # ConceptTable.grade is compared as a string, as in concept_func and
        # sub_concept_func, so integer grades from JSON still match.
        all_concepts = db.session.execute(
            Select(ConceptTable).where(
                ConceptTable.grade == str(grade),
                ConceptTable.stream == stream
            )
        ).scalars().all()

        # Fetch every concept's subconcepts in one query instead of one per concept.
        names_by_concept = {}
        concept_ids = [c.id for c in all_concepts]
        if concept_ids:
            subconcept_rows = db.session.execute(
                Select(SubConceptTable)
                .where(SubConceptTable.concept_id.in_(concept_ids))
                .order_by(SubConceptTable.id)
            ).scalars().all()
            for sc in subconcept_rows:
                names_by_concept.setdefault(sc.concept_id, []).append(sc.name)

        concepts_list = [
            {
                "id": c.id,
                "name": c.name,
                "grade": c.grade,
                "stream": c.stream,
                "subconcepts": names_by_concept.get(c.id, [])
            }
            for c in all_concepts
        ]

        overall_students = get_overall_weak_students(school_id, grade, section)

        return jsonify({
            "success": True,
            "concepts": concepts_list,
            "data": overall_students
        }), 200

    # =========================
    # Get specific concept/subconcept data
    # =========================
    if concept_name == "overall":
        students = get_overall_weak_students(school_id, grade, section)
    elif subconcept_name == "All Subconcepts":
        students = get_concept_weak_students(school_id, grade, section, concept_name)
    else:
        students = get_subconcept_weak_students(
            school_id, grade, section, concept_name, subconcept_name
        )

    return jsonify({
        "success": True,
        "data": students
    }), 200


