本文目录导读:
在Python中,模型保存主要涉及机器学习/深度学习框架,以下是几种常见框架的保存案例:
Scikit-learn 模型保存
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
# 训练模型
iris = load_iris()
X, y = iris.data, iris.target
model = RandomForestClassifier()
model.fit(X, y)
# 方法1:使用pickle保存
with open('model.pkl', 'wb') as f:
pickle.dump(model, f)
# 方法2:使用joblib(推荐,效率更高)
import joblib
joblib.dump(model, 'model.joblib')
# 加载模型
loaded_model = joblib.load('model.joblib')
predictions = loaded_model.predict(X[:5])
print(predictions)
PyTorch 模型保存
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
# 方法1:保存整个模型
torch.save(model, 'model_full.pth')
loaded_model = torch.load('model_full.pth')
# 方法2:仅保存参数(推荐)
torch.save(model.state_dict(), 'model_weights.pth')
model_loaded = SimpleModel()
model_loaded.load_state_dict(torch.load('model_weights.pth'))
# 方法3:保存检查点(包含优化器状态)
checkpoint = {
'epoch': 10,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': 0.5
}
torch.save(checkpoint, 'checkpoint.pth')
TensorFlow/Keras 模型保存
import tensorflow as tf
# 创建模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
tf.keras.layers.Dense(2, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy')
# 方法1:保存整个模型(包含架构、权重、优化器状态)
model.save('my_model.keras') # 或 .h5
loaded_model = tf.keras.models.load_model('my_model.keras')
# 方法2:仅保存权重
model.save_weights('model_weights.weights.h5')
model_new = tf.keras.Sequential([...]) # 需要相同的架构
model_new.load_weights('model_weights.weights.h5')
# 方法3:保存为SavedModel格式(生产部署推荐)
model.save('saved_model/', save_format='tf')
XGBoost 模型保存
import xgboost as xgb
from sklearn.datasets import load_iris
# 训练模型
iris = load_iris()
dtrain = xgb.DMatrix(iris.data, label=iris.target)
params = {'max_depth': 3, 'eta': 0.1, 'objective': 'multi:softmax'}
model = xgb.train(params, dtrain, num_boost_round=10)
# 保存模型
model.save_model('xgb_model.json')
# 加载模型
loaded_model = xgb.Booster()
loaded_model.load_model('xgb_model.json')
最佳实践建议
保存检查点(训练过程中的保存)
import torch
class ModelTrainer:
def __init__(self, model, optimizer, save_dir='checkpoints'):
self.model = model
self.optimizer = optimizer
self.save_dir = save_dir
def save_checkpoint(self, epoch, loss, filename='checkpoint.pth'):
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': loss
}
torch.save(checkpoint, f'{self.save_dir}/{filename}')
def load_checkpoint(self, filename='checkpoint.pth'):
checkpoint = torch.load(f'{self.save_dir}/{filename}')
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch'], checkpoint['loss']
模型版本控制
import datetime
def save_model_with_version(model, base_name='model'):
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'{base_name}_{timestamp}.pkl'
joblib.dump(model, filename)
return filename
注意事项
- 安全性:避免使用pickle加载不受信任的模型文件
- 兼容性:确保加载模型时使用相同的框架版本
- 文件格式选择:
- Scikit-learn:推荐使用joblib
- PyTorch:推荐保存state_dict
- TensorFlow:生产环境推荐SavedModel格式
- 跨平台:注意Windows和Linux路径差异
- 大数据量:考虑使用HDF5或ONNX格式
这些案例涵盖了Python中主流的模型保存方式,你可以根据具体框架和需求选择合适的保存方法。
标签: 案例实现