import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.core.services.auth import AuthService
from app.config import settings

@pytest.mark.asyncio
async def test_handle_callback_success():
    """Exercise the happy path of ``AuthService.handle_callback``.

    The token endpoint, the JWKS endpoint and the PyJWT helpers are replaced
    with mocks so the callback resolves to a fixed set of claims. The test then
    checks that the ``(user_id, linked)`` pair produced by the mocked
    ``user_service.save_user`` shows up in the returned mapping.
    """
    # Collaborators: save_user is stubbed to report ("user_123", True).
    services = MagicMock()
    services.user_service.save_user.return_value = ("user_123", True)
    auth_service = AuthService(services)
    db = MagicMock()

    # Canned HTTP responses: one for the token exchange, one for the JWKS fetch.
    token_response = MagicMock()
    token_response.json.return_value = {"id_token": "mock_id_token"}
    token_response.raise_for_status = MagicMock()

    jwks_response = MagicMock()
    jwks_response.json.return_value = {"keys": [{"kid": "key_1", "kty": "RSA"}]}
    jwks_response.raise_for_status = MagicMock()

    # Stand-in for ``httpx.AsyncClient(...)``. It supports ``async with``
    # (entering yields the same mock), and its post/get are awaitable and
    # resolve to the canned responses above.
    http_client = MagicMock()
    http_client.post = AsyncMock(return_value=token_response)
    http_client.get = AsyncMock(return_value=jwks_response)
    http_client.__aenter__.return_value = http_client
    http_client.__aexit__.return_value = None

    # Key object handed back when the JWK set is indexed by kid.
    signing_key = MagicMock()
    signing_key.key = "publicKey"
    claims = {"sub": "oidc_123", "email": "user@example.com"}

    with patch("httpx.AsyncClient", return_value=http_client), \
            patch("jwt.get_unverified_header", return_value={"kid": "key_1"}), \
            patch("jwt.PyJWKSet.from_dict") as from_dict, \
            patch("jwt.decode", return_value=claims):
        from_dict.return_value.__getitem__.return_value = signing_key
        result = await auth_service.handle_callback(code="test_code", db=db)

    assert result["user_id"] == "user_123"
    assert result["linked"] is True
