Python案例如何实现参数校验?

wen python案例 17

本文目录导读:

Python案例如何实现参数校验?

  1. 使用装饰器进行参数校验
  2. 使用Pydantic库(推荐)
  3. 使用dataclasses结合自定义校验
  4. 使用marshmallow库
  5. 综合示例:完整的参数校验系统
  6. 最佳实践建议

在Python中实现参数校验有多种方法,从简单到复杂,我为你介绍几种常见方案:

使用装饰器进行参数校验

from functools import wraps
from typing import Any, Callable
def validate_params(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(*args, **kwargs):
        # 校验参数类型
        if len(args) > 0:
            arg = args[0]
            if not isinstance(arg, (int, float)):
                raise TypeError(f"参数必须为数字类型,当前类型为: {type(arg).__name__}")
            if arg < 0:
                raise ValueError("参数不能为负数")
        # 校验关键字参数
        if 'name' in kwargs:
            if not isinstance(kwargs['name'], str):
                raise TypeError(f"name必须为字符串类型")
            if len(kwargs['name']) == 0:
                raise ValueError("name不能为空")
        return func(*args, **kwargs)
    return wrapper
# 使用示例
@validate_params
def calculate_square(number: int):
    return number ** 2
@validate_params
def greet(name: str):
    return f"Hello, {name}!"
# 测试
try:
    print(calculate_square(5))  # 正常
    print(calculate_square(-1))  # 报错
except ValueError as e:
    print(f"错误: {e}")
try:
    print(greet("Alice"))  # 正常
    print(greet(""))  # 报错
except ValueError as e:
    print(f"错误: {e}")

使用Pydantic库(推荐)

from pydantic import BaseModel, Field, validator
from typing import Optional, List
from datetime import datetime
class User(BaseModel):
    # 字段定义和校验
    username: str = Field(..., min_length=3, max_length=20)
    email: str
    age: int = Field(..., ge=0, le=150)
    password: str = Field(..., min_length=6)
    tags: List[str] = []
    created_at: Optional[datetime] = None
    # 自定义校验器
    @validator('email')
    def validate_email(cls, v):
        if '@' not in v:
            raise ValueError('邮箱格式不正确')
        return v.lower()
    @validator('password')
    def validate_password(cls, v):
        if not any(char.isdigit() for char in v):
            raise ValueError('密码必须包含数字')
        if not any(char.isupper() for char in v):
            raise ValueError('密码必须包含大写字母')
        return v
# 使用示例
def create_user(user_data: dict):
    """创建用户,自动进行参数校验"""
    try:
        user = User(**user_data)
        return {"status": "success", "data": user.dict()}
    except Exception as e:
        return {"status": "error", "message": str(e)}
# 测试
valid_user = {
    "username": "john_doe",
    "email": "john@example.com",
    "age": 25,
    "password": "SecurePass123",
    "tags": ["python", "developer"]
}
print(create_user(valid_user))
# 测试无效数据
invalid_user = {
    "username": "jd",  # 太短
    "email": "invalid",  # 无效邮箱
    "age": 200,  # 年龄超范围
    "password": "weak"  # 密码太弱
}
print(create_user(invalid_user))

使用dataclasses结合自定义校验

from dataclasses import dataclass, field
from typing import List, Optional
import re
@dataclass
class Product:
    name: str
    price: float
    quantity: int = 0
    tags: List[str] = field(default_factory=list)
    sku: Optional[str] = None
    def __post_init__(self):
        """数据校验逻辑"""
        # 名称校验
        if not self.name or len(self.name.strip()) == 0:
            raise ValueError("产品名称不能为空")
        if len(self.name) > 100:
            raise ValueError("产品名称不能超过100个字符")
        # 价格校验
        if self.price <= 0:
            raise ValueError("产品价格必须大于0")
        if self.price > 1000000:
            raise ValueError("产品价格不能超过1,000,000")
        # 数量校验
        if self.quantity < 0:
            raise ValueError("产品数量不能为负数")
        if self.quantity > 10000:
            raise ValueError("单次操作数量不能超过10,000")
        # SKU格式校验(如果有)
        if self.sku:
            if not re.match(r'^[A-Z]{2}-\d{4}-[A-Z]{3}$', self.sku):
                raise ValueError("SKU格式错误,示例: PR-2024-ABC")
# 使用示例
def process_product_order(product_data: dict):
    """处理产品订单"""
    try:
        product = Product(**product_data)
        return {
            "status": "success",
            "message": f"产品 {product.name} 校验通过",
            "total_value": product.price * product.quantity
        }
    except ValueError as e:
        return {"status": "error", "message": str(e)}
# 测试
print(process_product_order({
    "name": "笔记本电脑",
    "price": 5999.99,
    "quantity": 10,
    "sku": "EL-2024-LPT"
}))
print(process_product_order({
    "name": "",  # 无效名称
    "price": -100,  # 无效价格
    "quantity": -1  # 无效数量
}))

使用marshmallow库

from marshmallow import Schema, fields, validate, ValidationError
class OrderSchema(Schema):
    """订单参数校验"""
    order_id = fields.String(required=True, validate=validate.Length(min=6, max=20))
    customer_name = fields.String(required=True, validate=validate.Length(min=2, max=50))
    email = fields.Email(required=True)
    amount = fields.Float(required=True, validate=validate.Range(min=0.01))
    items = fields.List(fields.Dict(), required=True, validate=validate.Length(min=1))
    shipping_address = fields.String(required=True)
    discount_code = fields.String(validate=validate.Length(max=20), missing=None)
class ItemSchema(Schema):
    """商品项校验"""
    product_id = fields.String(required=True)
    quantity = fields.Integer(required=True, validate=validate.Range(min=1, max=100))
    unit_price = fields.Float(required=True, validate=validate.Range(min=0.01))
# 使用示例
def create_order(order_data: dict):
    """创建订单并进行参数校验"""
    schema = OrderSchema()
    try:
        validated_data = schema.load(order_data)
        return {"status": "success", "data": validated_data}
    except ValidationError as err:
        return {"status": "error", "errors": err.messages}
# 测试
valid_order = {
    "order_id": "ORD-2024-001",
    "customer_name": "张三",
    "email": "zhangsan@example.com",
    "amount": 199.99,
    "items": [{"product_id": "P001", "quantity": 2, "unit_price": 99.995}],
    "shipping_address": "北京市朝阳区xxx",
    "discount_code": "SAVE10"
}
print(create_order(valid_order))

综合示例:完整的参数校验系统

from typing import Any, Dict, List, Optional, Callable
from functools import wraps
import re
from dataclasses import dataclass
from enum import Enum
class ValidationLevel(Enum):
    STRICT = "strict"
    MODERATE = "moderate"
    LOOSE = "loose"
@dataclass
class ValidationRule:
    """校验规则定义"""
    field_name: str
    required: bool = True
    field_type: type = str
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    regex: Optional[str] = None
    custom_validator: Optional[Callable] = None
class ParameterValidator:
    """参数校验器"""
    def __init__(self, rules: List[ValidationRule], level: ValidationLevel = ValidationLevel.MODERATE):
        self.rules = rules
        self.level = level
    def validate(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """执行参数校验"""
        errors = {}
        for rule in self.rules:
            field_name = rule.field_name
            value = params.get(field_name)
            # 1. 检查必需字段
            if value is None:
                if rule.required:
                    errors[field_name] = f"字段 '{field_name}' 是必需的"
                continue
            # 2. 类型检查
            if not isinstance(value, rule.field_type):
                errors[field_name] = f"字段 '{field_name}' 应为 {rule.field_type.__name__} 类型"
                continue
            # 3. 长度检查
            if rule.min_length is not None and isinstance(value, str) and len(value) < rule.min_length:
                errors[field_name] = f"字段 '{field_name}' 最小长度为 {rule.min_length}"
            if rule.max_length is not None and isinstance(value, str) and len(value) > rule.max_length:
                errors[field_name] = f"字段 '{field_name}' 最大长度为 {rule.max_length}"
            # 4. 数值范围检查
            if rule.min_value is not None and isinstance(value, (int, float)) and value < rule.min_value:
                errors[field_name] = f"字段 '{field_name}' 最小值应为 {rule.min_value}"
            if rule.max_value is not None and isinstance(value, (int, float)) and value > rule.max_value:
                errors[field_name] = f"字段 '{field_name}' 最大值应为 {rule.max_value}"
            # 5. 正则表达式检查
            if rule.regex and isinstance(value, str):
                if not re.match(rule.regex, value):
                    errors[field_name] = f"字段 '{field_name}' 格式不正确"
            # 6. 自定义校验器
            if rule.custom_validator:
                try:
                    rule.custom_validator(value)
                except ValueError as e:
                    errors[field_name] = str(e)
        return errors
def with_validation(rules: List[ValidationRule]):
    """参数校验装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            validator = ParameterValidator(rules)
            errors = validator.validate(kwargs)
            if errors:
                return {"status": "error", "errors": errors}
            return func(*args, **kwargs)
        return wrapper
    return decorator
# 使用示例
def validate_email(email: str):
    """自定义邮箱校验"""
    if '@' not in email:
        raise ValueError("邮箱格式无效")
    if email.count('@') != 1:
        raise ValueError("邮箱格式无效")
    parts = email.split('@')
    if len(parts[0]) == 0 or len(parts[1]) == 0:
        raise ValueError("邮箱格式无效")
# 定义校验规则
user_rules = [
    ValidationRule("username", required=True, field_type=str, min_length=3, max_length=20),
    ValidationRule("email", required=True, field_type=str, custom_validator=validate_email),
    ValidationRule("age", required=True, field_type=int, min_value=0, max_value=150),
    ValidationRule("phone", required=False, field_type=str, regex=r'^1[3-9]\d{9}$'),
]
@with_validation(user_rules)
def register_user(**user_data):
    """注册用户"""
    return {
        "status": "success",
        "message": f"用户 {user_data['username']} 注册成功",
        "user_info": user_data
    }
# 测试
print("=== 测试1: 有效数据 ===")
result = register_user(
    username="john_doe",
    email="john@example.com",
    age=25,
    phone="13800138000"
)
print(result)
print("\n=== 测试2: 无效数据 ===")
result = register_user(
    username="jd",  # 太短
    email="invalid-email",  # 无效邮箱
    age=200,  # 年龄超范围
    phone="12345"  # 无效手机号
)
print(result)

最佳实践建议

  1. 选择合适的校验方式

    • 简单项目:使用装饰器或dataclass
    • 中型项目:使用Pydantic
    • 复杂API:使用marshmallow或完整校验系统
  2. 校验原则

    • 及早校验:在数据进入核心逻辑前完成校验
    • 友好提示:提供清晰的错误信息
    • 安全性:检查SQL注入、XSS等安全风险
  3. 性能考虑

    • 生产环境建议启用缓存
    • 批量数据处理时考虑异步校验
  4. 测试覆盖

    • 测试边界值
    • 测试异常情况
    • 测试不同类型的数据

这些方案可以根据项目需求灵活选择,建议从简单开始,逐步升级到更完善的校验方案。

抱歉,评论功能暂时关闭!