本文目录导读:
- PyTorch 模型加载(最常用)
- TensorFlow / Keras 模型加载
- Hugging Face Transformers(NLP 模型)
- Scikit-learn 模型加载(传统机器学习)
- ONNX 模型加载(跨平台/跨框架推理)
- TFLite(移动端/边缘设备)
- 特殊情况:多 GPU 模型加载
Python模型加载案例非常丰富,主要取决于你使用的框架(如PyTorch、TensorFlow、Scikit-learn)、模型格式(如.pth、.h5、.onnx)以及部署场景(推理、迁移学习、服务化)。
下面按常用框架分类,列举几个典型的加载案例及代码示例。
PyTorch 模型加载(最常用)
PyTorch 通常推荐保存模型的状态字典(state_dict),加载时需先实例化相同的模型结构。
案例 A:加载完整模型(包含结构,不推荐用于生产)
import torch
import torchvision.models as models
# 1. 保存时使用 torch.save(model, 'model.pth')
# 2. 加载
model = torch.load('model.pth')
model.eval() # 切换到评估模式
案例 B:加载 state_dict(标准做法)
import torch
import torchvision.models as models
# 1. 先定义与训练时完全相同的模型结构
model = models.resnet18(weights=None, num_classes=10)
# 2. 加载权重
state_dict = torch.load('model_state_dict.pth', map_location='cpu') # 防止GPU不存在
model.load_state_dict(state_dict)
# 3. 切换到评估模式
model.eval()
# 4. 推理
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
output = model(dummy_input)
案例 C:加载 Checkpoint(包含优化器状态,用于继续训练)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
TensorFlow / Keras 模型加载
案例 A:加载整个模型(.h5 或 SavedModel 格式)
import tensorflow as tf
# 加载保存的完整模型(结构+权重+优化器)
model = tf.keras.models.load_model('my_model.h5')
# 推理
import numpy as np
dummy_input = np.random.rand(1, 224, 224, 3)
predictions = model.predict(dummy_input)
案例 B:仅加载权重(需先定义模型结构)
# 先定义相同结构
model = tf.keras.Sequential([...]) # 或使用函数式API
# 加载权重
model.load_weights('model_weights.h5')
# 编译(如果需要进行评估或继续训练)
model.compile(optimizer='adam', loss='categorical_crossentropy')
Hugging Face Transformers(NLP 模型)
案例 A:从本地文件加载预训练模型
from transformers import AutoModel, AutoTokenizer
# 假设已经下载或手动保存了模型到本地目录
model = AutoModel.from_pretrained('./my_bert_model/')
tokenizer = AutoTokenizer.from_pretrained('./my_bert_model/')
# 推理
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs)
案例 B:通过模型库名称加载(自动下载)
from transformers import pipeline
# 一行加载并推理
classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
result = classifier("I love this movie!")
Scikit-learn 模型加载(传统机器学习)
import joblib
# 加载之前用 joblib.dump 保存的模型
model = joblib.load('trained_model.pkl')
# 推理
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)
ONNX 模型加载(跨平台/跨框架推理)
ONNX 允许你在不同框架(如 PyTorch → ONNX → TensorRT)之间交换模型。
import onnxruntime as ort
import numpy as np
# 创建一个 ONNX Runtime 推理会话
ort_session = ort.InferenceSession('model.onnx')
# 获取输入输出名称
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
# 准备输入数据(必须与 ONNX 模型要求的 shape 一致)
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 推理
outputs = ort_session.run([output_name], {input_name: input_data})
TFLite(移动端/边缘设备)
import tensorflow as tf import numpy as np # 加载 TFLite 模型 interpreter = tf.lite.Interpreter(model_path='model.tflite') interpreter.allocate_tensors() # 获取输入输出 tensors input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() # 填充输入 input_data = np.random.randn(1, 224, 224, 3).astype(np.float32) interpreter.set_tensor(input_details[0]['index'], input_data) # 推理 interpreter.invoke() output_data = interpreter.get_tensor(output_details[0]['index'])
特殊情况:多 GPU 模型加载
如果模型是使用 DataParallel 训练保存的,权重键名会带有 module. 前缀。
import torch
# 方法1:移除前缀
state_dict = torch.load('model.pth')
new_state_dict = {}
for k, v in state_dict.items():
name = k.replace("module.", "") # 去掉module.前缀
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
# 方法2:直接加载到单个GPU或CPU(PyTorch 1.1+ 会自动处理)
model = torch.nn.DataParallel(model) # 重新包装
model.load_state_dict(torch.load('model.pth'))
| 场景 | 推荐格式 | 加载方式 |
|---|---|---|
| PyTorch 训练/推理 | .pth, .pt |
torch.load() + model.load_state_dict() |
| TensorFlow 部署 | .h5 或 SavedModel |
tf.keras.models.load_model() |
| 传统机器学习 (sklearn) | .pkl |
joblib.load() |
| 跨框架/移动端 | .onnx, .tflite |
onnxruntime.InferenceSession() / tf.lite.Interpreter() |
| NLP 模型 (Hugging Face) | config.json, pytorch_model.bin |
AutoModel.from_pretrained() |
小提示: 生产环境中,建议将模型转换为 ONNX 或 TensorRT(针对NVIDIA GPU)以获得最佳性能和可移植性。