diff --git a/crawl4ai/deep_crawling/base_strategy.py b/crawl4ai/deep_crawling/base_strategy.py index e1b3fe6bd..55cd6e5b5 100644 --- a/crawl4ai/deep_crawling/base_strategy.py +++ b/crawl4ai/deep_crawling/base_strategy.py @@ -32,13 +32,13 @@ async def result_wrapper(): async for result in result_obj: yield result finally: - self.deep_crawl_active.reset(token) + self.deep_crawl_active.set(False) return result_wrapper() else: try: return result_obj finally: - self.deep_crawl_active.reset(token) + self.deep_crawl_active.set(False) return await original_arun(url, config=config, **kwargs) return wrapped_arun diff --git a/tests/deep_crawling/test_deep_crawl_contextvar.py b/tests/deep_crawling/test_deep_crawl_contextvar.py new file mode 100644 index 000000000..4f1bfd436 --- /dev/null +++ b/tests/deep_crawling/test_deep_crawl_contextvar.py @@ -0,0 +1,340 @@ +""" +Test Suite: Deep Crawl ContextVar Safety (Issue #1917) + +Tests that DeepCrawlDecorator's ContextVar (deep_crawl_active) works correctly +when the async generator is consumed in a different asyncio context, as happens +with Starlette's StreamingResponse in the Docker API. + +The bug: base_strategy.py used ContextVar.reset(token) in the generator's finally +block, but reset() requires the same Context that created the token. When Starlette +consumes the generator in a different Task, the Context changes -> ValueError. + +The fix: use ContextVar.set(False) instead of reset(token), which works across +context boundaries. +""" + +import pytest +import asyncio +from unittest.mock import MagicMock, AsyncMock + +from crawl4ai.deep_crawling.base_strategy import DeepCrawlDecorator + + +# ============================================================================ +# Helpers +# ============================================================================ + +def create_mock_result(url="https://example.com"): + result = MagicMock() + result.url = url + result.success = True + result.metadata = {} + result.links = {"internal": [], "external": []} + return result + + +def create_streaming_strategy(results): + """Create a mock deep crawl strategy that streams results.""" + strategy = MagicMock() + + async def mock_arun(start_url, crawler, config): + async def gen(): + for r in results: + yield r + return gen() + + strategy.arun = mock_arun + return strategy + + +def create_batch_strategy(results): + """Create a mock deep crawl strategy that returns results as a list.""" + strategy = MagicMock() + + async def mock_arun(start_url, crawler, config): + return results + + strategy.arun = mock_arun + return strategy + + +def create_config(stream=False, strategy=None): + config = MagicMock() + config.stream = stream + config.deep_crawl_strategy = strategy + return config + + +# ============================================================================ +# Tests: ContextVar cross-context safety (the core #1917 bug) +# ============================================================================ + +class TestContextVarCrossContext: + """Tests that deep_crawl_active ContextVar works across task boundaries.""" + + @pytest.mark.asyncio + async def test_streaming_generator_consumed_in_different_task(self): + """ + Core reproduction of issue #1917: + Create the generator in one task, consume it in another. + Before the fix, this raised ValueError. + """ + mock_results = [create_mock_result(f"https://example.com/{i}") for i in range(3)] + strategy = create_streaming_strategy(mock_results) + config = create_config(stream=True, strategy=strategy) + + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + async def original_arun(url, config=None, **kwargs): + return create_mock_result(url) + + wrapped = decorator(original_arun) + + # Call wrapped_arun in the current task — sets the token + gen = await wrapped("https://example.com", config=config) + + # Consume in a DIFFERENT task (simulates Starlette's StreamingResponse) + collected = [] + + async def consume_in_new_task(): + async for result in gen: + collected.append(result) + + task = asyncio.create_task(consume_in_new_task()) + await task + + assert len(collected) == 3 + + @pytest.mark.asyncio + async def test_batch_mode_in_different_task(self): + """Non-streaming mode should also work across task boundaries.""" + mock_results = [create_mock_result("https://example.com")] + strategy = create_batch_strategy(mock_results) + config = create_config(stream=False, strategy=strategy) + + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + async def original_arun(url, config=None, **kwargs): + return create_mock_result(url) + + wrapped = decorator(original_arun) + result = await wrapped("https://example.com", config=config) + + assert result == mock_results + + +# ============================================================================ +# Tests: ContextVar state management +# ============================================================================ + +class TestContextVarState: + """Tests that deep_crawl_active is properly managed.""" + + @pytest.mark.asyncio + async def test_flag_is_false_after_streaming_completes(self): + """deep_crawl_active should be False after the generator is exhausted.""" + strategy = create_streaming_strategy([create_mock_result()]) + config = create_config(stream=True, strategy=strategy) + + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + async def original_arun(url, config=None, **kwargs): + return create_mock_result(url) + + wrapped = decorator(original_arun) + gen = await wrapped("https://example.com", config=config) + + async for _ in gen: + pass + + assert decorator.deep_crawl_active.get() == False + + @pytest.mark.asyncio + async def test_flag_is_false_after_batch_completes(self): + """deep_crawl_active should be False after batch mode completes.""" + strategy = create_batch_strategy([create_mock_result()]) + config = create_config(stream=False, strategy=strategy) + + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + async def original_arun(url, config=None, **kwargs): + return create_mock_result(url) + + wrapped = decorator(original_arun) + await wrapped("https://example.com", config=config) + + assert decorator.deep_crawl_active.get() == False + + @pytest.mark.asyncio + async def test_flag_is_true_during_deep_crawl(self): + """deep_crawl_active should be True while the generator is being consumed.""" + flag_during_yield = None + + async def capturing_arun(start_url, crawler, config): + async def gen(): + nonlocal flag_during_yield + flag_during_yield = DeepCrawlDecorator.deep_crawl_active.get() + yield create_mock_result(start_url) + return gen() + + strategy = MagicMock() + strategy.arun = capturing_arun + config = create_config(stream=True, strategy=strategy) + + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + async def original_arun(url, config=None, **kwargs): + return create_mock_result(url) + + wrapped = decorator(original_arun) + gen = await wrapped("https://example.com", config=config) + + async for _ in gen: + pass + + assert flag_during_yield == True + + @pytest.mark.asyncio + async def test_flag_prevents_recursive_deep_crawl(self): + """When deep_crawl_active is True, nested calls should skip deep crawl.""" + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + inner_call_hit = False + + async def original_arun(url, config=None, **kwargs): + nonlocal inner_call_hit + inner_call_hit = True + return create_mock_result(url) + + wrapped = decorator(original_arun) + + # Manually set the flag to simulate being inside a deep crawl + decorator.deep_crawl_active.set(True) + try: + strategy = create_batch_strategy([create_mock_result()]) + config = create_config(stream=False, strategy=strategy) + # Should call original_arun directly, NOT go through strategy + result = await wrapped("https://example.com", config=config) + assert inner_call_hit == True + finally: + decorator.deep_crawl_active.set(False) + + @pytest.mark.asyncio + async def test_flag_reset_after_streaming_error(self): + """deep_crawl_active should be reset even if the generator raises.""" + strategy = MagicMock() + + async def failing_arun(start_url, crawler, config): + async def gen(): + yield create_mock_result("https://example.com") + raise RuntimeError("simulated error") + return gen() + + strategy.arun = failing_arun + config = create_config(stream=True, strategy=strategy) + + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + async def original_arun(url, config=None, **kwargs): + return create_mock_result(url) + + wrapped = decorator(original_arun) + gen = await wrapped("https://example.com", config=config) + + with pytest.raises(RuntimeError, match="simulated error"): + async for _ in gen: + pass + + assert decorator.deep_crawl_active.get() == False + + @pytest.mark.asyncio + async def test_flag_reset_after_streaming_error_in_different_task(self): + """ + Combines #1917 fix with error handling: generator raises in a different task. + Both the cross-context issue and error cleanup must work together. + """ + strategy = MagicMock() + + async def failing_arun(start_url, crawler, config): + async def gen(): + yield create_mock_result("https://example.com") + raise RuntimeError("simulated error") + return gen() + + strategy.arun = failing_arun + config = create_config(stream=True, strategy=strategy) + + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + async def original_arun(url, config=None, **kwargs): + return create_mock_result(url) + + wrapped = decorator(original_arun) + gen = await wrapped("https://example.com", config=config) + + error_caught = False + + async def consume_in_new_task(): + nonlocal error_caught + try: + async for _ in gen: + pass + except RuntimeError: + error_caught = True + + task = asyncio.create_task(consume_in_new_task()) + await task + + assert error_caught == True + + +# ============================================================================ +# Tests: Concurrent requests +# ============================================================================ + +class TestConcurrentRequests: + """Tests that multiple concurrent streaming deep crawls don't interfere.""" + + @pytest.mark.asyncio + async def test_concurrent_streaming_in_separate_tasks(self): + """ + Multiple concurrent streaming requests consumed in separate tasks. + This simulates multiple clients hitting /crawl/stream simultaneously. + """ + crawler = MagicMock() + decorator = DeepCrawlDecorator(crawler) + + async def original_arun(url, config=None, **kwargs): + return create_mock_result(url) + + wrapped = decorator(original_arun) + results_per_request = {} + + async def simulate_request(request_id): + mock_results = [create_mock_result(f"https://example.com/{request_id}/{i}") for i in range(2)] + strategy = create_streaming_strategy(mock_results) + config = create_config(stream=True, strategy=strategy) + + gen = await wrapped(f"https://example.com/{request_id}", config=config) + results = [] + async for result in gen: + results.append(result) + await asyncio.sleep(0.01) # Interleave with other requests + results_per_request[request_id] = results + + tasks = [asyncio.create_task(simulate_request(i)) for i in range(3)] + await asyncio.gather(*tasks) + + assert len(results_per_request) == 3 + for request_id, results in results_per_request.items(): + assert len(results) == 2 + + assert decorator.deep_crawl_active.get() == False