Django 单元测试的核心是 `django.test.TestCase` 类。它基于 Python 标准库的 `unittest` 模块,并扩展了数据库事务隔离、测试客户端等特性。以下是编写 Django 单元测试的完整指南:
基本结构
# tests/test_models.py
from django.test import TestCase
from myapp.models import Product
class ProductModelTest(TestCase):
    """Basic model tests for ``Product``."""

    def setUp(self):
        """Runs before every test method: create one fresh product."""
        attrs = {"name": "Test Product", "price": 99.99, "stock": 10}
        self.product = Product.objects.create(**attrs)

    def test_product_creation(self):
        """The created product keeps its name and reports itself available."""
        created = self.product
        self.assertEqual(created.name, "Test Product")
        self.assertTrue(created.is_available())

    def test_product_string_representation(self):
        """str() of a product is its name."""
        rendered = str(self.product)
        self.assertEqual(rendered, "Test Product")
测试模型的常用写法
class ProductModelTest(TestCase):
    """Model tests for ``Product`` that share one ``Category`` fixture."""

    @classmethod
    def setUpTestData(cls):
        """Class-level fixture: runs once and is shared by every test in the class.

        Cheaper than creating the category in ``setUp`` for each test.
        """
        from myapp.models import Category

        cls.category = Category.objects.create(name="Electronics")

    def setUp(self):
        """Per-test fixture: each test method gets its own fresh product."""
        self.product = Product.objects.create(
            name="Laptop",
            category=self.category,
            price=1999.99,
        )

    def test_price_is_positive(self):
        """A negative price must fail model validation on the ``price`` field."""
        from django.core.exceptions import ValidationError

        # Give the instance a valid category so a missing required field cannot
        # make full_clean() raise, which would let this test pass for the wrong
        # reason. Assumes the model validates that price is positive.
        product = Product(name="Invalid", category=self.category, price=-10)
        with self.assertRaises(ValidationError) as ctx:
            product.full_clean()
        self.assertIn("price", ctx.exception.message_dict)

    def test_stock_default_value(self):
        """``stock`` falls back to the model default when not supplied."""
        # Assumes Product.stock is declared with default=0; confirm in the model.
        self.assertEqual(self.product.stock, 0)
测试视图
# tests/test_views.py
from django.test import TestCase, Client
from django.urls import reverse
class ProductViewTest(TestCase):
    """Tests for the HTML product views: list, detail and create."""

    def setUp(self):
        """Create one product that the list/detail tests look for."""
        # django.test.TestCase already gives every test a fresh ``self.client``,
        # so no explicit Client() is created here.
        self.product = Product.objects.create(
            name="Test Product",
            price=100,
            stock=5,
        )

    def test_product_list_view(self):
        """The list view renders its template and shows the product."""
        url = reverse('product-list')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        self.assertTemplateUsed(response, 'products/list.html')
        self.assertContains(response, "Test Product")

    def test_product_detail_view(self):
        """The detail view for an existing product shows its name."""
        url = reverse('product-detail', args=[self.product.id])
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        self.assertContains(response, self.product.name)

    def test_product_create_view_get(self):
        """GET on the create page renders an HTML form."""
        url = reverse('product-create')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        # Match the opening tag: a bare 'form' substring would also match
        # words such as "platform" or "information" anywhere in the page.
        self.assertContains(response, '<form')

    def test_product_create_view_post(self):
        """A valid POST creates the product and redirects."""
        url = reverse('product-create')
        data = {
            'name': 'New Product',
            'price': 50,
            'stock': 10,
        }
        response = self.client.post(url, data)
        self.assertEqual(response.status_code, 302)  # redirect after success
        # The product from setUp plus the newly created one.
        self.assertEqual(Product.objects.count(), 2)
        self.assertTrue(Product.objects.filter(name='New Product').exists())

    def test_product_list_pagination(self):
        """The list view splits 26 products into pages of 10."""
        # 25 new products + the one from setUp = 26 in total.
        for i in range(25):
            Product.objects.create(name=f"Product {i}", price=10)
        url = reverse('product-list')
        # Django's ListView does not paginate by default; this assumes the view
        # sets paginate_by = 10, context_object_name = 'products' and uses the
        # default 'page' query parameter.
        response = self.client.get(url)
        self.assertEqual(len(response.context['products']), 10)
        # The last page holds the remaining 26 - 2 * 10 = 6 products.
        last_page = self.client.get(url, {'page': 3})
        self.assertEqual(len(last_page.context['products']), 6)
测试表单
# tests/test_forms.py
from django.test import TestCase
from myapp.forms import ProductForm
class ProductFormTest(TestCase):
    """Validation tests for ``ProductForm``."""

    def test_product_form_valid(self):
        """A form with a name, a positive price and stock validates."""
        form_data = {
            'name': 'New Product',
            'price': 99.99,
            'stock': 10,
        }
        form = ProductForm(data=form_data)
        self.assertTrue(form.is_valid())

    def test_product_form_invalid(self):
        """An empty name and a negative price are both reported as field errors."""
        form_data = {
            'name': '',
            'price': -10,
            'stock': -5,
        }
        form = ProductForm(data=form_data)
        self.assertFalse(form.is_valid())
        self.assertIn('name', form.errors)
        self.assertIn('price', form.errors)

    def test_product_form_price_validation(self):
        """The form's custom price validation rejects a price of 10000."""
        form_data = {
            'name': 'Expensive Product',
            'price': 10000,
            'stock': 1,
        }
        form = ProductForm(data=form_data)
        self.assertFalse(form.is_valid())
        # Pin the failure to the price field so the test cannot pass because
        # of an unrelated error on another field.
        self.assertIn('price', form.errors)
测试 API (Django REST Framework)
# tests/test_api.py
from rest_framework.test import APITestCase
from rest_framework import status
from django.urls import reverse
class ProductAPITest(APITestCase):
    """REST API tests for the product endpoints: list, create and retrieve."""

    def setUp(self):
        """Create one product that is visible through the API."""
        self.product = Product.objects.create(name="API Product", price=199.99, stock=20)

    def test_list_products(self):
        """GET on the list endpoint returns 200 and exactly one product."""
        list_url = reverse('product-api-list')
        resp = self.client.get(list_url, format='json')
        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        # NOTE(review): counting response.data assumes the endpoint is not
        # paginated; with DRF pagination the payload is a dict of
        # count/next/previous/results instead of a list.
        self.assertEqual(len(resp.data), 1)

    def test_create_product(self):
        """POST on the list endpoint creates a second product."""
        list_url = reverse('product-api-list')
        payload = {'name': 'New Product', 'price': 50, 'stock': 10}
        resp = self.client.post(list_url, payload, format='json')
        self.assertEqual(resp.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Product.objects.count(), 2)

    def test_retrieve_product(self):
        """GET on the detail endpoint returns the product's name."""
        detail_url = reverse('product-api-detail', args=[self.product.id])
        resp = self.client.get(detail_url, format='json')
        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['name'], self.product.name)
测试需要认证的视图
from django.contrib.auth.models import User
from django.test import TestCase, Client
class AuthViewTest(TestCase):
    """Access-control tests for views that require a logged-in user."""

    def setUp(self):
        """Create the user whose credentials the login tests use."""
        # TestCase already provides ``self.client``; only the user is needed.
        self.user = User.objects.create_user(
            username='testuser',
            password='testpass123',
        )

    def test_unauthenticated_access(self):
        """Anonymous users are redirected to the login page."""
        url = reverse('dashboard')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 302)  # redirect to the login page
        # A login redirect carries the originally requested path as ?next=.
        self.assertIn('next=', response.url)

    def test_authenticated_access(self):
        """A logged-in user gets the dashboard page."""
        # login() returns False on bad credentials; asserting it makes a broken
        # fixture fail here instead of as a confusing redirect further down.
        self.assertTrue(
            self.client.login(username='testuser', password='testpass123')
        )
        url = reverse('dashboard')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        self.assertTemplateUsed(response, 'dashboard.html')

    def test_login_required_decorator(self):
        """A @login_required view redirects anonymous users and serves logged-in ones."""
        from django.contrib.auth.decorators import login_required
        from django.contrib.auth.models import AnonymousUser
        from django.http import HttpResponse
        from django.test import RequestFactory

        @login_required
        def protected_view(request):
            return HttpResponse()

        # The decorated view is not routed in any URLconf, so it is called
        # directly with requests built by RequestFactory.
        factory = RequestFactory()

        anonymous_request = factory.get('/protected/')
        anonymous_request.user = AnonymousUser()
        response = protected_view(anonymous_request)
        self.assertEqual(response.status_code, 302)
        self.assertIn('next=', response.url)

        authenticated_request = factory.get('/protected/')
        authenticated_request.user = self.user
        response = protected_view(authenticated_request)
        self.assertEqual(response.status_code, 200)
测试信号
# tests/test_signals.py
from django.test import TestCase
from unittest.mock import patch
from myapp.signals import product_created
class SignalTest(TestCase):
    """Tests for the product-related signal handlers."""

    @patch('myapp.signals.send_welcome_email')
    def test_product_created_signal(self, mock_send_email):
        """Creating a Product makes the handler call send_welcome_email once."""
        # patch() replaces the name inside myapp.signals, so this only works if
        # the handler looks the function up in that module at call time.
        product = Product.objects.create(
            name="Signal Product",
            price=100,
        )
        # Verify that the signal handler invoked the patched function.
        mock_send_email.assert_called_once_with(product)

    def test_signal_receiver_count(self):
        """Exactly one post_save receiver named handle_product_save is connected."""
        import weakref

        from django.db.models.signals import post_save

        def _resolve(ref):
            # Receivers connected with weak=True (the default) are stored as
            # weak references (weakref.ref / WeakMethod) and must be called to
            # get the function back; a dead reference yields None.
            return ref() if isinstance(ref, weakref.ref) else ref

        # ``receivers`` is an internal list of tuples whose second item is the
        # receiver or a weak reference to it. The tuple length differs between
        # Django versions, so only index 1 is relied on.
        resolved = (_resolve(entry[1]) for entry in post_save.receivers)
        product_receivers = [
            receiver
            for receiver in resolved
            if getattr(receiver, '__name__', None) == 'handle_product_save'
        ]
        self.assertEqual(len(product_receivers), 1)
使用测试工具
# test_with_tools.py
from django.test import TestCase, override_settings
from unittest.mock import patch, MagicMock
import tempfile
class ToolTest(TestCase):
    """Examples of common testing tools: override_settings, mock.patch, uploads."""

    @override_settings(DEBUG=True)
    def test_debug_mode(self):
        """override_settings changes a setting only for the decorated test."""
        from django.conf import settings

        self.assertTrue(settings.DEBUG)

    @patch('myapp.services.calculate_tax')
    def test_mock_external_service(self, mock_calculate):
        """Replace calculate_tax with a mock so process_order runs in isolation."""
        from myapp.services import process_order

        # A MagicMock stands in for the order object; use a real order (e.g.
        # built by a factory) if process_order reads model fields from it.
        order = MagicMock()
        mock_calculate.return_value = 15.00
        result = process_order(order)
        self.assertEqual(result['tax'], 15.00)
        mock_calculate.assert_called_once_with(order)

    def test_file_upload(self):
        """Posting a tiny in-memory GIF to the upload endpoint succeeds."""
        from django.core.files.uploadedfile import SimpleUploadedFile

        # Bytes of a minimal valid 1x1 GIF89a image.
        small_gif = (
            b'\x47\x49\x46\x38\x39\x61'
            b'\x01\x00\x01\x00\x00\x00'
            b'\x00\x21\xf9\x04\x00\x00'
            b'\x00\x00\x00\x2c\x00\x00'
            b'\x00\x00\x01\x00\x01\x00'
            b'\x00\x02\x02\x44\x01\x00'
            b'\x3b'
        )
        uploaded = SimpleUploadedFile(
            'small.gif',
            small_gif,
            content_type='image/gif',
        )
        response = self.client.post('/upload/', {'file': uploaded})
        self.assertEqual(response.status_code, 200)
组织和运行测试
目录结构
myapp/
├── tests/
│ ├── __init__.py
│ ├── test_models.py
│ ├── test_views.py
│ ├── test_forms.py
│ └── test_api.py
运行测试
```bash
# 运行所有测试
python manage.py test
# 运行特定 app 的测试
python manage.py test myapp
# 运行特定测试类
python manage.py test myapp.tests.test_models.ProductModelTest
# 运行特定测试方法
python manage.py test myapp.tests.test_models.ProductModelTest.test_product_creation
# 运行某个测试包(以点分模块路径指定)
python manage.py test myapp.tests
# 并行运行测试(需要多个 CPU 核心)
python manage.py test --parallel
```
覆盖率报告
```bash
# 安装覆盖率工具
pip install coverage
# 运行测试并收集覆盖率数据
coverage run --source='myapp' manage.py test myapp
# 在终端输出覆盖率报告
coverage report
# 生成 HTML 报告(默认输出到 htmlcov/ 目录)
coverage html
```
最佳实践
1. **测试命名规范**
   - 测试类:`Test<Feature>` 或 `<Feature>Test`(如本文示例中的 `ProductModelTest`),在项目内保持统一
   - 测试方法:`test_<action>_<expected_result>`
2. **保持测试独立**
   - 每个测试方法都应该能够单独运行
   - 不要依赖测试的执行顺序
3. **使用工厂模式**
# Use factory_boy to simplify creating test data.
import factory
from myapp.models import Product


class ProductFactory(factory.django.DjangoModelFactory):
    """Builds and saves ``Product`` rows with generated field values."""

    class Meta:
        model = Product

    # Sequence gives every instance a unique name: "Product 0", "Product 1", ...
    name = factory.Sequence(lambda n: f'Product {n}')
    # positive=True keeps generated prices from being negative; otherwise
    # pydecimal can return values that the model's price validation rejects.
    price = factory.Faker('pydecimal', left_digits=3, right_digits=2, positive=True)
    stock = factory.Faker('random_int', min=0, max=100)
在测试中使用
from decimal import Decimal

# All fields generated by the factory.
product = ProductFactory()
# Override one field. Pass prices as Decimal: a float given to a DecimalField
# stays a float on the in-memory instance until it is reloaded from the database.
product2 = ProductFactory(price=Decimal('200.00'))
4. **测试速度优化**
- 当测试数据可以在同一测试类的多个方法之间共享时,使用 `setUpTestData()`(每个测试类只执行一次)代替 `setUp()`(每个测试方法都执行一次)
- 优先使用 `TestCase`(每个测试在事务中运行,结束后回滚);只有在需要测试真实的事务提交/回滚行为时才使用 `TransactionTestCase`,它通过清空数据表来重置数据库,速度明显更慢
- 合理使用 `@override_settings` 而不是修改配置文件
5. **异步测试**
```python
from django.test import TestCase
import asyncio


class AsyncTest(TestCase):
    """Shows an async test method that uses the built-in ``async_client``."""

    async def test_async_view(self):
        """An awaited GET to the async endpoint returns HTTP 200."""
        endpoint = '/async-endpoint/'
        reply = await self.async_client.get(endpoint)
        self.assertEqual(reply.status_code, 200)
```

Django 单元测试是保证代码质量的基石。通过系统化的测试覆盖,可以及早发现问题、简化重构过程,并为项目的长期维护提供保障。