Files
stegasoo/tests/test_batch.py
2025-12-30 00:28:58 -05:00

292 lines
9.5 KiB
Python

"""
Tests for Stegasoo batch processing module.
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, patch
from stegasoo.batch import (
BatchProcessor,
BatchResult,
BatchItem,
BatchStatus,
batch_capacity_check,
print_batch_result,
)
@pytest.fixture
def temp_dir():
"""Create a temporary directory for tests."""
path = Path(tempfile.mkdtemp())
yield path
shutil.rmtree(path)
@pytest.fixture
def sample_images(temp_dir):
"""Create sample PNG images for testing."""
from PIL import Image
images = []
for i in range(3):
img_path = temp_dir / f"test_image_{i}.png"
img = Image.new('RGB', (100, 100), color=(i * 50, i * 50, i * 50))
img.save(img_path, 'PNG')
images.append(img_path)
return images
class TestBatchItem:
"""Tests for BatchItem dataclass."""
def test_duration_calculation(self):
"""Duration should be calculated from start/end times."""
item = BatchItem(input_path=Path("test.png"))
item.start_time = 100.0
item.end_time = 105.5
assert item.duration == 5.5
def test_duration_none_without_times(self):
"""Duration should be None if times not set."""
item = BatchItem(input_path=Path("test.png"))
assert item.duration is None
def test_to_dict(self):
"""to_dict should serialize all fields."""
item = BatchItem(
input_path=Path("input.png"),
output_path=Path("output.png"),
status=BatchStatus.SUCCESS,
message="Done",
)
result = item.to_dict()
assert result['input_path'] == "input.png"
assert result['output_path'] == "output.png"
assert result['status'] == "success"
class TestBatchResult:
"""Tests for BatchResult dataclass."""
def test_to_json(self):
"""Should serialize to valid JSON."""
import json
result = BatchResult(operation="encode", total=5, succeeded=4, failed=1)
json_str = result.to_json()
parsed = json.loads(json_str)
assert parsed['operation'] == "encode"
assert parsed['summary']['total'] == 5
def test_duration_with_end_time(self):
"""Duration should work when end_time is set."""
result = BatchResult(operation="test")
result.start_time = 100.0
result.end_time = 110.0
assert result.duration == 10.0
class TestBatchProcessor:
"""Tests for BatchProcessor class."""
def test_init_default_workers(self):
"""Should default to 4 workers."""
processor = BatchProcessor()
assert processor.max_workers == 4
def test_init_custom_workers(self):
"""Should accept custom worker count."""
processor = BatchProcessor(max_workers=8)
assert processor.max_workers == 8
def test_is_valid_image_png(self, temp_dir):
"""Should recognize PNG as valid."""
processor = BatchProcessor()
png_path = temp_dir / "test.png"
png_path.touch()
assert processor._is_valid_image(png_path)
def test_is_valid_image_txt(self, temp_dir):
"""Should reject non-image files."""
processor = BatchProcessor()
txt_path = temp_dir / "test.txt"
txt_path.touch()
assert not processor._is_valid_image(txt_path)
def test_find_images_file(self, sample_images):
"""Should find single image file."""
processor = BatchProcessor()
results = list(processor.find_images([sample_images[0]]))
assert len(results) == 1
assert results[0] == sample_images[0]
def test_find_images_directory(self, sample_images, temp_dir):
"""Should find images in directory."""
processor = BatchProcessor()
results = list(processor.find_images([temp_dir]))
assert len(results) == 3
def test_find_images_recursive(self, temp_dir):
"""Should find images recursively."""
from PIL import Image
# Create nested directory
nested = temp_dir / "nested"
nested.mkdir()
img_path = nested / "nested.png"
img = Image.new('RGB', (50, 50))
img.save(img_path)
processor = BatchProcessor()
results = list(processor.find_images([temp_dir], recursive=True))
assert any(p.name == "nested.png" for p in results)
def test_batch_encode_requires_message_or_file(self, sample_images):
"""Should raise if neither message nor file provided."""
processor = BatchProcessor()
with pytest.raises(ValueError, match="message or file_payload"):
processor.batch_encode(
images=sample_images,
credentials={"phrase": "test", "pin": "123456"},
)
def test_batch_encode_requires_credentials(self, sample_images):
"""Should raise if credentials not provided."""
processor = BatchProcessor()
with pytest.raises(ValueError, match="Credentials"):
processor.batch_encode(
images=sample_images,
message="test",
)
def test_batch_encode_creates_result(self, sample_images, temp_dir):
"""Should return BatchResult with correct structure."""
processor = BatchProcessor()
result = processor.batch_encode(
images=sample_images,
message="Test message",
output_dir=temp_dir / "output",
credentials={"phrase": "test phrase", "pin": "123456"},
)
assert isinstance(result, BatchResult)
assert result.operation == "encode"
assert result.total == 3
assert len(result.items) == 3
def test_batch_decode_requires_credentials(self, sample_images):
"""Should raise if credentials not provided."""
processor = BatchProcessor()
with pytest.raises(ValueError, match="Credentials"):
processor.batch_decode(images=sample_images)
def test_batch_decode_creates_result(self, sample_images):
"""Should return BatchResult with correct structure."""
processor = BatchProcessor()
result = processor.batch_decode(
images=sample_images,
credentials={"phrase": "test phrase", "pin": "123456"},
)
assert isinstance(result, BatchResult)
assert result.operation == "decode"
assert result.total == 3
def test_progress_callback_called(self, sample_images):
"""Progress callback should be called for each item."""
processor = BatchProcessor()
callback = Mock()
processor.batch_encode(
images=sample_images,
message="Test",
credentials={"phrase": "test", "pin": "123456"},
progress_callback=callback,
)
assert callback.call_count == 3
def test_custom_encode_func(self, sample_images, temp_dir):
"""Should use custom encode function if provided."""
processor = BatchProcessor()
encode_mock = Mock()
processor.batch_encode(
images=sample_images,
message="Test",
output_dir=temp_dir / "output",
credentials={"phrase": "test", "pin": "123456"},
encode_func=encode_mock,
)
assert encode_mock.call_count == 3
class TestBatchCapacityCheck:
"""Tests for batch_capacity_check function."""
def test_returns_list(self, sample_images):
"""Should return list of results."""
results = batch_capacity_check(sample_images)
assert isinstance(results, list)
assert len(results) == 3
def test_includes_capacity(self, sample_images):
"""Results should include capacity info."""
results = batch_capacity_check(sample_images)
for item in results:
assert 'capacity_bytes' in item
assert 'dimensions' in item
assert 'valid' in item
def test_handles_invalid_files(self, temp_dir):
"""Should handle non-image files gracefully."""
bad_file = temp_dir / "not_an_image.png"
bad_file.write_bytes(b"not a png")
results = batch_capacity_check([bad_file])
assert len(results) == 1
assert 'error' in results[0]
class TestPrintBatchResult:
"""Tests for print_batch_result function."""
def test_prints_summary(self, capsys, sample_images):
"""Should print summary without errors."""
result = BatchResult(
operation="encode",
total=3,
succeeded=2,
failed=1,
)
result.end_time = result.start_time + 5.0
print_batch_result(result)
captured = capsys.readouterr()
assert "ENCODE" in captured.out
assert "3" in captured.out # total
assert "2" in captured.out # succeeded
def test_verbose_shows_items(self, capsys):
"""Verbose mode should show individual items."""
result = BatchResult(operation="decode", total=1, succeeded=1)
result.items = [
BatchItem(
input_path=Path("test.png"),
status=BatchStatus.SUCCESS,
message="Decoded successfully",
)
]
result.end_time = result.start_time + 1.0
print_batch_result(result, verbose=True)
captured = capsys.readouterr()
assert "test.png" in captured.out