在DRF中,source
参数用于在序列化器字段和模型字段之间建立映射关系。本文将从用途和示例出发,帮助你理解其作用。
它的主要用途如下:
- 字段重命名
- 访问嵌套模型字段
- 处理多对多关系
1. 字段重命名
假如我们有如下model:
from django.db import models


class Order(models.Model):
    """Minimal order model whose only declared data field is ``name``."""

    # Order name, limited to 255 characters.
    name = models.CharField(max_length=255)

    def __str__(self):
        """Use the order name as the string form of the instance."""
        return self.name
当前端需要order_name
字段名来替代name
时,我们可以创建如下的序列化器:
from rest_framework import serializers
from api.models import Order


class OrderSerializer(serializers.ModelSerializer):
    """Serialize ``Order`` while exposing the model's ``name`` as ``order_name``."""

    # ``source='name'`` binds this serializer field to ``Order.name``.
    order_name = serializers.CharField(source='name')

    class Meta:
        model = Order
        fields = ['id', 'order_name']
在上面的示例中,我们使用source
参数创建了序列化器字段order_name
和模型字段name
之间的映射关系:当我们在进行反序列化的时候,order_name
将填充到name
中,当进行序列化时从查询集中获取的name
字段将被用于填充order_name
。
在下面的测试中,观察validated_data
和序列化的结果,你可以清晰地看到这一点:
from rest_framework.test import APITestCase
from .models import Order
from .serializers import OrderSerializer


class TestSource(APITestCase):
    """Show how ``source`` maps ``order_name`` onto the model field ``name``."""

    def test_source(self):
        """Round-trip one order through ``OrderSerializer`` in both directions."""
        # Only fields declared on the serializer are sent; at this stage the
        # model has just ``name``, so there is no ``notes`` or ``created_by``.
        data = {
            'order_name': "测试订单",
        }

        # Deserialization: the incoming ``order_name`` ends up under ``name``.
        s = OrderSerializer(data=data)
        s.is_valid(raise_exception=True)
        print(s.validated_data)
        # {'name': '测试订单'}

        order = s.save()

        # Serialization: ``name`` read back from the database fills
        # ``order_name``. Look the row up by the saved pk rather than
        # assuming it is 1.
        data = OrderSerializer(Order.objects.get(pk=order.pk)).data
        print(data)
        # {'id': 1, 'order_name': '测试订单'}
2.访问嵌套(关系)模型字段
现在为刚刚的Order
模型添加一个外键:
from django.db import models
from django.contrib.auth.models import User


class Order(models.Model):
    """Order with a name, notes and an optional link to the creating user."""

    name = models.CharField(max_length=255)
    notes = models.CharField(max_length=255)
    # Optional creator; ``on_delete=models.CASCADE`` removes the order when
    # the referenced user is deleted.
    created_by = models.ForeignKey(
        User,
        on_delete=models.CASCADE,
        null=True,
        blank=True,
    )

    def __str__(self):
        """Use the order name as the string form of the instance."""
        return self.name
更改序列化器如下:
from rest_framework import serializers
from api.models import Order


class OrderSerializer(serializers.ModelSerializer):
    """Serialize ``Order`` with a renamed field and a field read across the FK."""

    # Renamed field: ``order_name`` is backed by ``Order.name``.
    order_name = serializers.CharField(source='name')
    # Read-only string form of the related user.
    created_by = serializers.StringRelatedField(read_only=True)
    # Dotted ``source`` follows the ``created_by`` relation to its ``username``.
    create_user_name = serializers.CharField(
        source='created_by.username',
        read_only=True,
    )

    class Meta:
        model = Order
        fields = ['id', 'notes', 'created_by', 'create_user_name', 'order_name']
测试一下:
from django.contrib.auth.models import User
from rest_framework.test import APITestCase

from .models import Order
from .serializers import OrderSerializer


class TestSource(APITestCase):
    """Show a dotted ``source`` reading a field of the related user."""

    def test_source(self):
        """Create an order for a user and print both serializer directions."""
        user = User.objects.create_user(username='用户1', password='test')
        data = {
            'order_name': "测试订单",
            'notes': "用于测试",
        }

        s = OrderSerializer(data=data)
        s.is_valid(raise_exception=True)
        print(s.validated_data)
        # {'notes': '用于测试', 'name': '测试订单'}

        # ``created_by`` is read-only on the serializer, so it is supplied here.
        order = s.save(created_by=user)

        # Re-fetch by the saved pk rather than assuming the row id is 1.
        data = OrderSerializer(Order.objects.get(pk=order.pk)).data
        print(data)
        # {'id': 1, 'notes': '用于测试', 'created_by': '用户1', 'create_user_name': '用户1', 'order_name': '测试订单'}
3.处理多对多关系
模型如下:
from django.db import models


class Blog(models.Model):
    """Blog post with a title and a body text."""

    title = models.CharField(max_length=255)
    detail = models.TextField()


class BlogGroup(models.Model):
    """Named group of blogs, linked to ``Blog`` through a many-to-many field."""

    name = models.CharField(max_length=255)
    # ``related_name='groups'`` exposes the reverse side as ``blog.groups``.
    blogs = models.ManyToManyField(Blog, related_name='groups')

    def __str__(self):
        """Use the group name as the string form of the instance."""
        return self.name
假如有这样的需求:我们希望在创建blog的时候可以选择传入一组group_ids
,模型中Blog
和BlogGroup
是多对多关系,如果使用重写ModelSerializer
的create
方法显得很麻烦,此时借助source
我们可以很轻松的完成这项工作:
from rest_framework import serializers
from .models import Blog, BlogGroup


class BlogSer(serializers.ModelSerializer):
    """Serialize ``Blog``; groups are read as strings and written as primary keys."""

    # Output only: each related group rendered as a string.
    groups = serializers.StringRelatedField(read_only=True, many=True)
    # Input only: a list of ``BlogGroup`` primary keys. ``source='groups'``
    # points the field at the same ``groups`` relation as the field above.
    group_ids = serializers.PrimaryKeyRelatedField(
        queryset=BlogGroup.objects.all(),
        write_only=True,
        source='groups',
        many=True,
    )

    class Meta:
        model = Blog
        fields = ['id', 'title', 'detail', 'groups', 'group_ids']
测试一下:
from rest_framework.test import APITestCase
from .serializers import BlogSer
from .models import Blog, BlogGroup


class TestSource(APITestCase):
    """Show ``group_ids`` being written into the ``groups`` relation via ``source``."""

    def test_source_get(self):
        """Create a blog attached to two groups and print both directions."""
        group_1 = BlogGroup.objects.create(name='测试组1')
        group_2 = BlogGroup.objects.create(name='测试组2')
        data = {
            'title': 'blog1',
            'detail': 'my blog detail.',
            # Send the ids of the groups created above; an empty list would
            # validate to an empty ``groups`` and attach nothing.
            'group_ids': [group_1.pk, group_2.pk],
        }

        s = BlogSer(data=data)
        s.is_valid(raise_exception=True)
        print(s.validated_data)
        # {'title': 'blog1', 'detail': 'my blog detail.', 'groups': [<BlogGroup: 测试组1>, <BlogGroup: 测试组2>]}

        blog = s.save()

        # Re-fetch by the saved pk rather than assuming the row id is 1.
        print(BlogSer(Blog.objects.get(pk=blog.pk)).data)
        # {'id': 1, 'title': 'blog1', 'detail': 'my blog detail.', 'groups': ['测试组1', '测试组2']}
上述的做法缺点也很明显:PrimaryKeyRelatedField会为group_ids中的每一个id单独执行一次查询,如果传入的id较多,则会导致大量的查询:
>>> data = {'title': 'blog1','detail': 'my blog detail.','group_ids': [1, 2, 3, 4, 5, 6, 7, 8, 9]}
>>> s = BlogSer(data=data)
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" LIMIT 21; args=(); alias=default
>>> s.is_valid(raise_exception=True)
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 1 LIMIT 21; args=(1,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 2 LIMIT 21; args=(2,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 3 LIMIT 21; args=(3,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 4 LIMIT 21; args=(4,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 5 LIMIT 21; args=(5,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 6 LIMIT 21; args=(6,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 7 LIMIT 21; args=(7,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 8 LIMIT 21; args=(8,); alias=default
# (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" WHERE "api_bloggroup"."id" = 9 LIMIT 21; args=(9,); alias=default
True
>>> (0.000) SELECT "api_bloggroup"."id", "api_bloggroup"."name" FROM "api_bloggroup" LIMIT 21; args=(); alias=default
如果你比较在意性能,或者明确知道此接口可能经常批量的处理大量的数据,还是自己重写create
、update
方法为好:
from rest_framework import serializers
from .models import Blog, BlogGroup


class BlogSer(serializers.ModelSerializer):
    """Serialize ``Blog`` and write its groups from a plain list of integer ids.

    ``group_ids`` is a write-only list of integers. All ids are checked with
    a single ``id__in`` query in ``validate_group_ids`` and then attached in
    ``create`` / ``update``.
    """

    # Output only: each related group rendered as a string.
    groups = serializers.StringRelatedField(read_only=True, many=True)
    # Input only: raw integer ids, validated in bulk below.
    group_ids = serializers.ListField(child=serializers.IntegerField(), write_only=True)

    class Meta:
        model = Blog
        fields = ['id', 'title', 'detail', 'groups', 'group_ids']

    def validate_group_ids(self, value):
        """Check with one query that every submitted id refers to a ``BlogGroup``.

        Returns ``value`` unchanged. Raises ``serializers.ValidationError``
        listing the unknown ids when any are missing.
        """
        existing_ids = set(
            BlogGroup.objects.filter(id__in=value).values_list('id', flat=True)
        )
        # Sorted so the error message is deterministic.
        invalid_ids = sorted(set(value) - existing_ids)
        if invalid_ids:
            raise serializers.ValidationError(
                f"Invalid group ids: {', '.join(map(str, invalid_ids))}"
            )
        return value

    def create(self, validated_data):
        """Create the blog, then attach the validated groups."""
        group_ids = validated_data.pop('group_ids')
        blog = Blog.objects.create(**validated_data)
        # ``set()`` accepts primary keys directly; the ids were already
        # validated, so the groups do not need to be fetched again.
        blog.groups.set(group_ids)
        return blog

    def update(self, instance, validated_data):
        """Update scalar fields and, if ``group_ids`` was sent, replace the groups."""
        # ``None`` means the key was absent (e.g. a partial update), in which
        # case the existing groups are left untouched.
        group_ids = validated_data.pop('group_ids', None)
        for attr, value in validated_data.items():
            setattr(instance, attr, value)
        # Persist scalar fields before touching the relation.
        instance.save()
        if group_ids is not None:
            instance.groups.set(group_ids)
        return instance