from sqlalchemy import Select, func, case, and_
from src.models import RankingTable, StudentTable
from src import db


def calculate_student_metrics(student_id: int):
    """Aggregate a student's attempt history into ranking metrics.

    Args:
        student_id: Value matched against ``RankingTable.student_id``.

    Returns:
        A dict with three keys:

        * ``total_correct`` - number of the student's rows where
          ``is_correct`` is true.
        * ``avg_time`` - mean of ``time_taken_sec`` over all of the
          student's rows (``0`` when the database returns NULL).
        * ``accuracy`` - ``total_correct`` as a percentage of all rows.

        All three are ``0`` when the student has no rows at all.
    """
    stmt = Select(
        func.count(RankingTable.id).label('total_attempts'),
        func.avg(RankingTable.time_taken_sec).label('avg_time'),
        # `== True` (not `is True`) is required so SQLAlchemy builds a SQL
        # expression instead of evaluating a Python identity check.
        func.sum(
            case((RankingTable.is_correct == True, 1), else_=0)  # noqa: E712
        ).label('correct_count'),
    ).where(RankingTable.student_id == student_id)

    row = db.session.execute(stmt).first()

    # No attempts recorded: nothing to average or divide by.
    if not row or not row.total_attempts:
        return {
            'total_correct': 0,
            'avg_time': 0,
            'accuracy': 0
        }

    total_attempts = row.total_attempts
    # Report the number of *correct* rows under 'total_correct'; counting
    # every row here would make the key a duplicate of the attempt count.
    # int() normalises backends that return SUM() as a Decimal.
    total_correct = int(row.correct_count or 0)
    avg_time = row.avg_time or 0

    # total_attempts is guaranteed non-zero by the guard above.
    accuracy = (total_correct / total_attempts) * 100

    return {
        'total_correct': total_correct,
        'avg_time': avg_time,
        'accuracy': accuracy
    }

def get_class_rankings(student_id: int, grade: str, section: str, school_id: int):
    """Rank one student against everyone in the same grade/section/school.

    Metrics for every classmate come from ``calculate_student_metrics``.
    Three orderings are built:

    * overall  - ``total_correct``, highest first (all classmates).
    * speed    - ``avg_time``, lowest first.
    * accuracy - ``accuracy``, highest first.

    Speed and accuracy only include classmates whose ``total_correct`` is
    greater than zero. Ties are not merged: tied students occupy consecutive
    positions, ordered by ascending student id.

    Args:
        student_id: Id of the student whose ranks are wanted.
        grade: Grade used to select the class.
        section: Section used to select the class.
        school_id: School used to select the class.

    Returns:
        ``{'overall': int, 'speed': int, 'accuracy': int}`` with 1-based
        positions. A student missing from an ordering gets the class size as
        rank; an empty class yields rank 1 for every metric.
    """
    # Explicit ORDER BY: without it the row order is unspecified, and because
    # the sorts below are stable, tied students could swap ranks between calls.
    students_stmt = (
        Select(StudentTable.id)
        .where(
            and_(
                StudentTable.grade == grade,
                StudentTable.section == section,
                StudentTable.school_id == school_id
            )
        )
        .order_by(StudentTable.id)
    )
    students = db.session.execute(students_stmt).scalars().all()

    if not students:
        return {'overall': 1, 'speed': 1, 'accuracy': 1}

    student_metrics = [
        {'student_id': stu_id, **calculate_student_metrics(stu_id)}
        for stu_id in students
    ]

    # Speed and accuracy share the same eligibility filter; build it once.
    eligible = [entry for entry in student_metrics if entry['total_correct'] > 0]

    overall_sorted = sorted(
        student_metrics, key=lambda entry: entry['total_correct'], reverse=True
    )
    speed_sorted = sorted(eligible, key=lambda entry: entry['avg_time'])
    accuracy_sorted = sorted(
        eligible, key=lambda entry: entry['accuracy'], reverse=True
    )

    def rank_in(ordering):
        """Return the 1-based position of ``student_id``, else the class size."""
        for position, entry in enumerate(ordering, start=1):
            if entry['student_id'] == student_id:
                return position
        return len(students)

    return {
        'overall': rank_in(overall_sorted),
        'speed': rank_in(speed_sorted),
        'accuracy': rank_in(accuracy_sorted)
    }
    
def get_school_rankings(student_id: int, school_id: int):
    """Rank one student against every student of the same school.

    Metrics for each student come from ``calculate_student_metrics``.
    Three orderings are built:

    * overall  - ``total_correct``, highest first (all students).
    * speed    - ``avg_time``, lowest first.
    * accuracy - ``accuracy``, highest first.

    Speed and accuracy only include students whose ``total_correct`` is
    greater than zero. Ties are not merged: tied students occupy consecutive
    positions, ordered by ascending student id.

    Args:
        student_id: Id of the student whose ranks are wanted.
        school_id: School whose students form the comparison group.

    Returns:
        ``{'overall': int, 'speed': int, 'accuracy': int}`` with 1-based
        positions. A student missing from an ordering gets the number of
        students in the school as rank; a school without students yields
        rank 1 for every metric.
    """
    # Explicit ORDER BY: without it the row order is unspecified, and because
    # the sorts below are stable, tied students could swap ranks between calls.
    students_stmt = (
        Select(StudentTable.id)
        .where(StudentTable.school_id == school_id)
        .order_by(StudentTable.id)
    )
    students = db.session.execute(students_stmt).scalars().all()

    if not students:
        return {'overall': 1, 'speed': 1, 'accuracy': 1}

    student_metrics = [
        {'student_id': stu_id, **calculate_student_metrics(stu_id)}
        for stu_id in students
    ]

    # Speed and accuracy share the same eligibility filter; build it once.
    eligible = [entry for entry in student_metrics if entry['total_correct'] > 0]

    orderings = {
        'overall': sorted(
            student_metrics, key=lambda entry: entry['total_correct'], reverse=True
        ),
        'speed': sorted(eligible, key=lambda entry: entry['avg_time']),
        'accuracy': sorted(
            eligible, key=lambda entry: entry['accuracy'], reverse=True
        ),
    }

    ranks = {}
    for metric, ordering in orderings.items():
        # Default to the group size when the student is not in this ordering.
        ranks[metric] = next(
            (
                position
                for position, entry in enumerate(ordering, start=1)
                if entry['student_id'] == student_id
            ),
            len(students),
        )
    return ranks

def update_student_rankings(student_id: int, grade: str, section: str, school_id: int):
    """Recompute a student's rankings and store the class-level ones.

    Described in the original docstring as running at login. Class ranks
    are computed first, then school ranks. Only the class ranks are written
    to the student row (``overall_ranking``, ``speed_ranking``,
    ``accuracy_ranking``); the school ranks are returned but not stored.

    Args:
        student_id: Id of the student to update.
        grade: Grade identifying the student's class.
        section: Section identifying the student's class.
        school_id: School identifying both the class and the school group.

    Returns:
        ``{'class_rankings': {...}, 'school_rankings': {...}}``. This is
        returned even when no student row matches ``student_id``, in which
        case nothing is written or committed.
    """
    # Dict literals evaluate in order, so class ranks are computed before
    # school ranks.
    rankings = {
        'class_rankings': get_class_rankings(student_id, grade, section, school_id),
        'school_rankings': get_school_rankings(student_id, school_id)
    }
    within_class = rankings['class_rankings']

    lookup = Select(StudentTable).where(StudentTable.id == student_id)
    record = db.session.execute(lookup).scalar_one_or_none()

    if record:
        # Map each persisted column to the class metric that feeds it.
        for column, metric in (
            ('overall_ranking', 'overall'),
            ('speed_ranking', 'speed'),
            ('accuracy_ranking', 'accuracy'),
        ):
            setattr(record, column, within_class[metric])
        db.session.commit()

    return rankings


def get_surrounding_students(student_id: int, grade: str, section: str, school_id: int, ranking_type: str = 'overall'):
    """Return the leaderboard window around a student within their class.

    The window holds up to three entries: the student ranked just above,
    the student themselves, and the student ranked just below.

    Args:
        student_id: Id of the student at the centre of the window.
        grade: Grade used to select the class.
        section: Section used to select the class.
        school_id: School used to select the class.
        ranking_type: ``'overall'`` (``total_correct``, highest first),
            ``'speed'`` (``avg_time``, lowest first) or ``'accuracy'``
            (``accuracy``, highest first). Speed and accuracy only include
            classmates whose ``total_correct`` is greater than zero. Any
            other value leaves classmates in query order (ascending id).

    Returns:
        A list of dicts in rank order. Each carries ``rank`` (1-based),
        ``student_id``, ``username``, ``firstname``, ``lastname``,
        ``avatar`` and the keys from ``calculate_student_metrics``; the
        centre entry also has ``is_current: True``. The list is empty when
        the class has no students or the student is not in the chosen
        ordering.
    """
    # Load the full rows in one query instead of selecting ids and re-querying
    # each student individually. The ORDER BY makes tie-breaking deterministic
    # because the sorts below are stable.
    classmates_stmt = (
        Select(StudentTable)
        .where(
            and_(
                StudentTable.grade == grade,
                StudentTable.section == section,
                StudentTable.school_id == school_id
            )
        )
        .order_by(StudentTable.id)
    )
    classmates = db.session.execute(classmates_stmt).scalars().all()

    if not classmates:
        return []

    student_metrics = [
        {
            'student_id': stu.id,
            'username': stu.username,
            'firstname': stu.firstname,
            'lastname': stu.lastname,
            'avatar': stu.avatar,
            **calculate_student_metrics(stu.id)
        }
        for stu in classmates
    ]

    # Sort based on ranking type
    if ranking_type == 'overall':
        sorted_students = sorted(
            student_metrics, key=lambda entry: entry['total_correct'], reverse=True
        )
    elif ranking_type in ('speed', 'accuracy'):
        eligible = [entry for entry in student_metrics if entry['total_correct'] > 0]
        if ranking_type == 'speed':
            sorted_students = sorted(eligible, key=lambda entry: entry['avg_time'])
        else:
            sorted_students = sorted(
                eligible, key=lambda entry: entry['accuracy'], reverse=True
            )
    else:
        # Unknown ranking type: keep the query order unchanged.
        sorted_students = student_metrics

    current_idx = next(
        (i for i, entry in enumerate(sorted_students) if entry['student_id'] == student_id),
        -1,
    )
    if current_idx == -1:
        return []

    # Walk the previous / current / next slots, skipping any outside the list.
    result = []
    for offset in (-1, 0, 1):
        idx = current_idx + offset
        if not 0 <= idx < len(sorted_students):
            continue
        entry = {'rank': idx + 1}
        if offset == 0:
            entry['is_current'] = True
        entry.update(sorted_students[idx])
        result.append(entry)

    return result

