英文:
Which solver is throwing this warning?
问题
这个警告弹出了10次。行号在456和305之间变化:
C:\Users\foo\Anaconda3\lib\site-packages\scipy\optimize_linesearch.py:456: LineSearchWarning: The line search algorithm did not converge
warn('The line search algorithm did not converge', LineSearchWarning)
我正在使用以下参数运行一个网格搜索:
logistic_regression_grid = {
"class_weight": ["balanced"],
"max_iter": [100000],
"solver": ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
"random_state": [0]
}
所以问题是哪个求解器(solver)引发了警告?能够确定吗?
英文:
This warning pops up 10 times. The line number varies between 456 and 305:
C:\Users\foo\Anaconda3\lib\site-packages\scipy\optimize\_linesearch.py:456: LineSearchWarning: The line search algorithm did not converge
warn('The line search algorithm did not converge', LineSearchWarning)
I'm running a grid search with these parameters:
logistic_regression_grid = {
"class_weight": ["balanced"],
"max_iter": [100000],
"solver": ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
"random_state": [0]
}
So, the question is which solver is throwing the warning? Is it possible to determine that?
答案1
得分: 1
以下是您提供的代码的中文翻译部分:
我使用了鸢尾花数据集,将`max_iter=10`设置为故意引发收敛警告。由于您只对求解器感兴趣,我在不使用网格搜索的情况下循环遍历了这些求解器,并且能够使用`warnings`库和`sklearn.exceptions`包打印出哪个求解器不收敛。以下是我的代码:
```python
import warnings
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 逻辑回归参数网格
logistic_regression_grid = {
"class_weight": ["balanced"],
"max_iter": [100000],
"solver": ["lbfgs", "liblinear", "newton-cg", "sag", "saga"],
"random_state": [0]
}
# 载入鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target
# 将数据拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 循环遍历求解器并捕获警告
for solver in logistic_regression_grid["solver"]:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
# 使用当前求解器拟合逻辑回归模型
model = LogisticRegression(class_weight="balanced", max_iter=10, solver=solver, random_state=0)
model.fit(X_train, y_train)
# 检查是否生成了任何警告
if any(issubclass(warning.category, ConvergenceWarning) for warning in w):
print(f"求解器 '{solver}' 未收敛。")
这是我得到的输出:
求解器 'lbfgs' 未收敛。
求解器 'newton-cg' 未收敛。
求解器 'sag' 未收敛。
求解器 'saga' 未收敛。
希望这能帮助您理解代码的功能和输出。
<details>
<summary>英文:</summary>
I used the iris set and I set `max_iter=10` to purposefully induce a convergence warning. Since you are interested only in the solvers I looped over the solvers without using grid search and I was able to print which solver does not converge using the `warnings` library and the `sklearn.exceptions` package. Here is my code:
import warnings
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
Your logistic regression grid
logistic_regression_grid = {
"class_weight": ["balanced"],
"max_iter": [100000],
"solver": ["lbfgs", "liblinear", "newton-cg", "sag", "saga"],
"random_state": [0]
}
Load the Iris dataset
iris = load_iris()
X, y = iris.data, iris.target
Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
Loop over the solvers and capture warnings
for solver in logistic_regression_grid["solver"]:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
# Fit logistic regression model with the current solver
model = LogisticRegression(class_weight="balanced", max_iter=10, solver=solver, random_state=0)
model.fit(X_train, y_train)
# Check if any warning was generated
if any(issubclass(warning.category, ConvergenceWarning) for warning in w):
print(f"Solver '{solver}' did not converge.")
Here is the output I get:
Solver 'lbfgs' did not converge.
Solver 'newton-cg' did not converge.
Solver 'sag' did not converge.
Solver 'saga' did not converge.
</details>
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论