Django单元测试如何编写?

访客 全栈框架 1

Django 单元测试的核心是 django.test.TestCase 类:它继承自 Python 标准库 unittest 的 TestCase,并把每个测试包裹在数据库事务中、测试结束后自动回滚,因此各测试之间的数据互不影响。以下是编写 Django 单元测试的完整指南:

基本结构

# tests/test_models.py
from django.test import TestCase
from myapp.models import Product


class ProductModelTest(TestCase):
    """Basic model tests: stored field values, a helper method and ``str()``."""

    def setUp(self):
        """Create a fresh product before every test method."""
        fields = {"name": "Test Product", "price": 99.99, "stock": 10}
        self.product = Product.objects.create(**fields)

    def test_product_creation(self):
        """The created product keeps its name and reports itself as available."""
        product = self.product
        self.assertEqual(product.name, "Test Product")
        self.assertTrue(product.is_available())

    def test_product_string_representation(self):
        """``str()`` of the product equals its name."""
        rendered = str(self.product)
        self.assertEqual(rendered, "Test Product")

测试模型的常用写法

from django.core.exceptions import ValidationError
from django.test import TestCase

# NOTE(review): assumes Category is defined in myapp.models next to Product.
from myapp.models import Category, Product


class ProductModelTest(TestCase):
    """Model tests showing class-level vs. per-test fixtures and field validation."""

    @classmethod
    def setUpTestData(cls):
        """Run once for the whole class; objects created here are shared by all tests."""
        cls.category = Category.objects.create(name="Electronics")

    def setUp(self):
        """Run before each test, so every test gets its own product."""
        self.product = Product.objects.create(
            name="Laptop",
            category=self.category,
            price=1999.99,
        )

    def test_price_is_positive(self):
        """A negative price must fail model validation on the ``price`` field."""
        # Give the instance a valid category so full_clean() cannot fail for an
        # unrelated reason (e.g. a missing required FK) and let the test pass by
        # accident; then check that the error really is reported for ``price``.
        product = Product(name="Invalid", category=self.category, price=-10)
        with self.assertRaises(ValidationError) as cm:
            product.full_clean()
        self.assertIn("price", cm.exception.message_dict)

    def test_stock_default_value(self):
        """``stock`` falls back to its model default when it is not passed."""
        # Presumes the model declares stock with default=0 — confirm against the model.
        self.assertEqual(self.product.stock, 0)

测试视图

# tests/test_views.py
from django.test import TestCase, Client
from django.urls import reverse

from myapp.models import Product


class ProductViewTest(TestCase):
    """Tests for the product views, driven through the Django test client."""

    def setUp(self):
        # TestCase already provides a fresh ``self.client`` for every test;
        # creating one explicitly is optional and shown here only for clarity.
        self.client = Client()
        self.product = Product.objects.create(
            name="Test Product",
            price=100,
            stock=5,
        )

    def test_product_list_view(self):
        """The list page renders with the expected template and shows the product."""
        url = reverse('product-list')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        self.assertTemplateUsed(response, 'products/list.html')
        self.assertContains(response, "Test Product")

    def test_product_detail_view(self):
        """The detail page of an existing product shows its name."""
        url = reverse('product-detail', args=[self.product.pk])
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        self.assertContains(response, self.product.name)

    def test_product_create_view_get(self):
        """A GET on the create page returns a page containing an HTML form."""
        url = reverse('product-create')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        # Match the tag itself: the bare word 'form' appears in almost any page.
        self.assertContains(response, '<form')

    def test_product_create_view_post(self):
        """A valid POST creates the product and then redirects."""
        url = reverse('product-create')
        data = {
            'name': 'New Product',
            'price': 50,
            'stock': 10,
        }
        response = self.client.post(url, data)
        self.assertEqual(response.status_code, 302)  # redirect after a successful save
        self.assertEqual(Product.objects.count(), 2)
        self.assertTrue(Product.objects.filter(name='New Product').exists())

    def test_product_list_pagination(self):
        """The list view returns only one page of products."""
        # Django has no default page size: this assumes the list view sets
        # ``paginate_by = 10`` and exposes the current page as ``products``.
        for i in range(25):
            Product.objects.create(name=f"Product {i}", price=10)
        url = reverse('product-list')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(len(response.context['products']), 10)

测试表单

# tests/test_forms.py
from django.test import TestCase
from myapp.forms import ProductForm


class ProductFormTest(TestCase):
    """Validation tests for ``ProductForm`` (no HTTP request involved)."""

    def test_product_form_valid(self):
        """Well-formed data passes validation."""
        form_data = {
            'name': 'New Product',
            'price': 99.99,
            'stock': 10,
        }
        form = ProductForm(data=form_data)
        # Pass form.errors as the message so a failure shows which field broke.
        self.assertTrue(form.is_valid(), form.errors)

    def test_product_form_invalid(self):
        """An empty name and a negative price are both rejected."""
        form_data = {
            'name': '',
            'price': -10,
            'stock': -5,
        }
        form = ProductForm(data=form_data)
        self.assertFalse(form.is_valid())
        self.assertIn('name', form.errors)
        self.assertIn('price', form.errors)

    def test_product_form_price_validation(self):
        """The form's custom price rule rejects a price of 10000."""
        form_data = {
            'name': 'Expensive Product',
            'price': 10000,
            'stock': 1,
        }
        form = ProductForm(data=form_data)
        self.assertFalse(form.is_valid())
        # Pin the failure to ``price`` so an unrelated invalid field cannot
        # make this test pass.
        self.assertIn('price', form.errors)

测试 API (Django REST Framework)

# tests/test_api.py
from rest_framework.test import APITestCase
from rest_framework import status
from django.urls import reverse

from myapp.models import Product


class ProductAPITest(APITestCase):
    """Tests for the Product REST endpoints using DRF's test client."""

    def setUp(self):
        self.product = Product.objects.create(
            name="API Product",
            price=199.99,
            stock=20,
        )

    def test_list_products(self):
        """The list endpoint returns every product."""
        url = reverse('product-api-list')
        response = self.client.get(url, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        data = response.data
        # With DRF's built-in pagination classes the items are wrapped in a dict
        # under 'results'; without pagination response.data is the list itself.
        items = data['results'] if isinstance(data, dict) else data
        self.assertEqual(len(items), 1)

    def test_create_product(self):
        """POSTing valid JSON creates a new product."""
        url = reverse('product-api-list')
        data = {'name': 'New Product', 'price': 50, 'stock': 10}
        response = self.client.post(url, data, format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Product.objects.count(), 2)
        self.assertTrue(Product.objects.filter(name='New Product').exists())

    def test_retrieve_product(self):
        """The detail endpoint returns the requested product."""
        url = reverse('product-api-detail', args=[self.product.pk])
        response = self.client.get(url, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data['name'], self.product.name)

测试需要认证的视图

from django.contrib.auth.models import User
from django.test import TestCase, Client
from django.urls import reverse


class AuthViewTest(TestCase):
    """Access-control tests for views that require a logged-in user."""

    def setUp(self):
        self.client = Client()
        self.user = User.objects.create_user(
            username='testuser',
            password='testpass123',
        )

    def test_unauthenticated_access(self):
        """Anonymous users are redirected to the login page."""
        url = reverse('dashboard')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 302)  # redirected to the login page
        # With the default redirect_field_name, the login redirect carries ?next=<path>.
        self.assertIn('next=', response.url)

    def test_authenticated_access(self):
        """A logged-in user gets the dashboard page."""
        logged_in = self.client.login(username='testuser', password='testpass123')
        self.assertTrue(logged_in)  # fail loudly if the credentials were rejected
        url = reverse('dashboard')
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        self.assertTemplateUsed(response, 'dashboard.html')

    def test_login_required_decorator(self):
        """Calling a @login_required view as an anonymous user redirects to login."""
        from django.contrib.auth.decorators import login_required
        from django.contrib.auth.models import AnonymousUser
        from django.http import HttpResponse
        from django.test import RequestFactory

        @login_required
        def protected_view(request):
            return HttpResponse()

        # RequestFactory builds the request without URL routing, so the decorated
        # view is exercised directly instead of depending on a URLconf entry.
        request = RequestFactory().get('/protected/')
        request.user = AnonymousUser()
        response = protected_view(request)
        self.assertEqual(response.status_code, 302)
        self.assertIn('next=', response.url)

测试信号

# tests/test_signals.py
from django.test import TestCase
from unittest.mock import patch
from myapp.signals import product_created

from myapp.models import Product


class SignalTest(TestCase):
    """Tests for the product-related signal handlers."""

    @patch('myapp.signals.send_welcome_email')
    def test_product_created_signal(self, mock_send_email):
        """Creating a product makes the handler call send_welcome_email once."""
        # Assumes the handler lives in myapp.signals and looks up
        # send_welcome_email there at call time, which is the name patched here.
        product = Product.objects.create(
            name="Signal Product",
            price=100,
        )
        # Verify the signal handler was invoked with the new product.
        mock_send_email.assert_called_once_with(product)

    def test_signal_receiver_count(self):
        """Exactly one receiver named ``handle_product_save`` is connected to post_save."""
        import weakref
        from django.db.models.signals import post_save

        def _resolve(receiver):
            # Receivers connected with the default weak=True are stored as weak
            # references (weakref.ref / WeakMethod), which have no __name__, so
            # dereference them first; a dead reference resolves to None.
            return receiver() if isinstance(receiver, weakref.ref) else receiver

        # Signal.receivers is internal: index 1 of each entry holds the receiver
        # (or a weak reference to it); newer Django versions append more fields.
        product_receivers = [
            entry for entry in post_save.receivers
            if getattr(_resolve(entry[1]), '__name__', None) == 'handle_product_save'
        ]
        self.assertEqual(len(product_receivers), 1)

使用测试工具

# test_with_tools.py
from django.test import TestCase, override_settings
from unittest.mock import patch, MagicMock
import tempfile


class ToolTest(TestCase):
    """Examples of common test helpers: settings overrides, mocks and file uploads."""

    @override_settings(DEBUG=True)
    def test_debug_mode(self):
        """override_settings changes a setting only while this test runs."""
        from django.conf import settings
        self.assertTrue(settings.DEBUG)

    @patch('myapp.services.calculate_tax')
    def test_mock_external_service(self, mock_calculate):
        """Replace calculate_tax with a mock so process_order runs without the real service."""
        # NOTE(review): assumes process_order is defined in myapp.services and
        # looks up calculate_tax there at call time; if it lives elsewhere,
        # patch the name in the module that actually uses it.
        from myapp.services import process_order

        # Stand-in order object; swap for a real Order instance if
        # process_order needs model behavior a mock cannot provide.
        order = MagicMock(name='order')
        mock_calculate.return_value = 15.00
        result = process_order(order)
        self.assertEqual(result['tax'], 15.00)
        mock_calculate.assert_called_once_with(order)

    def test_file_upload(self):
        """Upload a minimal GIF through a multipart POST."""
        from django.core.files.uploadedfile import SimpleUploadedFile

        # A minimal 1x1 GIF89a image.
        small_gif = (
            b'\x47\x49\x46\x38\x39\x61'
            b'\x01\x00\x01\x00\x00\x00'
            b'\x00\x21\xf9\x04\x00\x00'
            b'\x00\x00\x00\x2c\x00\x00'
            b'\x00\x00\x01\x00\x01\x00'
            b'\x00\x02\x02\x44\x01\x00'
            b'\x3b'
        )
        uploaded = SimpleUploadedFile(
            'small.gif',
            small_gif,
            content_type='image/gif',
        )
        # Store anything the view saves in a throwaway directory instead of
        # the project's real MEDIA_ROOT; it is removed when the block exits.
        with tempfile.TemporaryDirectory() as media_root, self.settings(MEDIA_ROOT=media_root):
            response = self.client.post('/upload/', {'file': uploaded})
        self.assertEqual(response.status_code, 200)

组织和运行测试

目录结构

myapp/
├── tests/
│   ├── __init__.py
│   ├── test_models.py
│   ├── test_views.py
│   ├── test_forms.py
│   └── test_api.py

运行测试

# Run all tests
python manage.py test
# Run the tests of a single app
python manage.py test myapp
# Run a single test class
python manage.py test myapp.tests.test_models.ProductModelTest
# Run a single test method
python manage.py test myapp.tests.test_models.ProductModelTest.test_product_creation
# Run one test package (given as a dotted module path)
python manage.py test myapp.tests
# Run tests in parallel (only helps on machines with multiple cores)
python manage.py test --parallel

覆盖率报告

# Install the coverage tool
pip install coverage
# Run the tests under coverage, measuring only the myapp package
coverage run --source='myapp' manage.py test myapp
coverage report
coverage html  # Generate an HTML report

最佳实践

  1. 测试命名规范

    • 测试类:继承 TestCase,按被测对象命名,如 ProductModelTest(类名本身不影响测试发现)
    • 测试方法:必须以 test 开头才会被执行,推荐 test_<action>_<expected_result>;测试文件名默认需匹配 test*.py
  2. 保持测试独立

    • 每个测试方法应该独立运行
    • 不要依赖测试执行顺序
  3. 使用工厂模式

    # Use factory_boy to simplify creating test data
from decimal import Decimal

import factory
from myapp.models import Product


class ProductFactory(factory.django.DjangoModelFactory):
    """Factory for ``Product``: unique sequential names plus random price and stock."""

    class Meta:
        model = Product

    name = factory.Sequence(lambda n: f'Product {n}')
    # positive=True: without it Faker's pydecimal can return negative prices,
    # which a positive-price validation rule would reject.
    price = factory.Faker('pydecimal', left_digits=3, right_digits=2, positive=True)
    stock = factory.Faker('random_int', min=0, max=100)


# Usage in tests
product = ProductFactory()
product2 = ProductFactory(price=Decimal('200.00'))


4. **测试速度优化**
   - 当测试数据在各个测试中不会被修改时,用 `setUpTestData()` 代替 `setUp()`,数据只需为整个测试类创建一次
   - 优先使用 `TestCase`;只有需要真实事务行为时才改用 `TransactionTestCase`,它在每个测试后清空数据表,速度慢得多
   - 合理使用 `@override_settings` 而不是修改配置文件
5. **异步测试**(Django 3.1+ 支持 `async def` 测试方法和 `self.async_client`)
from django.test import TestCase
import asyncio


class AsyncTest(TestCase):
    """Example of an ``async def`` test method using the async test client."""

    async def test_async_view(self):
        """GET the async endpoint through ``self.async_client`` and expect HTTP 200."""
        endpoint = '/async-endpoint/'
        response = await self.async_client.get(endpoint)
        self.assertEqual(response.status_code, 200)

Django 单元测试是确保代码质量的基石,通过系统化的测试覆盖,可以及早发现问题、简化重构过程,并为项目长期维护提供保障。

标签: Django单元测试 pytest

抱歉,评论功能暂时关闭!