哪个求解器引发了这个警告?

huangapple go评论87阅读模式
英文:

Which solver is throwing this warning?

问题

这个警告弹出了10次。行号在456和305之间变化:

C:\Users\foo\Anaconda3\lib\site-packages\scipy\optimize_linesearch.py:456: LineSearchWarning: The line search algorithm did not converge
warn('The line search algorithm did not converge', LineSearchWarning)

我正在使用以下参数运行一个网格搜索:

logistic_regression_grid = {
"class_weight": ["balanced"],
"max_iter": [100000],
"solver": ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
"random_state": [0]
}

所以问题是哪个求解器(solver)引发了警告?能够确定吗?

英文:

This warning pops up 10 times. The line number varies between 456 and 305:

  1. C:\Users\foo\Anaconda3\lib\site-packages\scipy\optimize\_linesearch.py:456: LineSearchWarning: The line search algorithm did not converge
  2. warn('The line search algorithm did not converge', LineSearchWarning)

I'm running a grid search with these parameters:

  1. logistic_regression_grid = {
  2. "class_weight": ["balanced"],
  3. "max_iter": [100000],
  4. "solver": ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
  5. "random_state": [0]
  6. }

So, the question is which solver is throwing the warning? Is it possible to determine that?

答案1

得分: 1

以下是您提供的代码的中文翻译部分:

  1. 我使用了鸢尾花数据集`max_iter=10`设置为故意引发收敛警告由于您只对求解器感兴趣我在不使用网格搜索的情况下循环遍历了这些求解器并且能够使用`warnings`库和`sklearn.exceptions`包打印出哪个求解器不收敛以下是我的代码
  2. ```python
  3. import warnings
  4. import numpy as np
  5. from sklearn.linear_model import LogisticRegression
  6. from sklearn.exceptions import ConvergenceWarning
  7. from sklearn.datasets import load_iris
  8. from sklearn.model_selection import train_test_split
  9. # 逻辑回归参数网格
  10. logistic_regression_grid = {
  11. "class_weight": ["balanced"],
  12. "max_iter": [100000],
  13. "solver": ["lbfgs", "liblinear", "newton-cg", "sag", "saga"],
  14. "random_state": [0]
  15. }
  16. # 载入鸢尾花数据集
  17. iris = load_iris()
  18. X, y = iris.data, iris.target
  19. # 将数据拆分为训练集和测试集
  20. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
  21. # 循环遍历求解器并捕获警告
  22. for solver in logistic_regression_grid["solver"]:
  23. with warnings.catch_warnings(record=True) as w:
  24. warnings.simplefilter("always")
  25. # 使用当前求解器拟合逻辑回归模型
  26. model = LogisticRegression(class_weight="balanced", max_iter=10, solver=solver, random_state=0)
  27. model.fit(X_train, y_train)
  28. # 检查是否生成了任何警告
  29. if any(issubclass(warning.category, ConvergenceWarning) for warning in w):
  30. print(f"求解器 '{solver}' 未收敛。")

这是我得到的输出:

  1. 求解器 'lbfgs' 未收敛
  2. 求解器 'newton-cg' 未收敛
  3. 求解器 'sag' 未收敛
  4. 求解器 'saga' 未收敛
  1. 希望这能帮助您理解代码的功能和输出。
  2. <details>
  3. <summary>英文:</summary>
  4. I used the iris set and I set `max_iter=10` to purposefully induce a convergence warning. Since you are interested only in the solvers I looped over the solvers without using grid search and I was able to print which solver does not converge using the `warnings` library and the `sklearn.exceptions` package. Here is my code:

import warnings
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

Your logistic regression grid

logistic_regression_grid = {
"class_weight": ["balanced"],
"max_iter": [100000],
"solver": ["lbfgs", "liblinear", "newton-cg", "sag", "saga"],
"random_state": [0]
}

Load the Iris dataset

iris = load_iris()
X, y = iris.data, iris.target

Split the data into training and testing sets

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

Loop over the solvers and capture warnings

for solver in logistic_regression_grid["solver"]:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

  1. # Fit logistic regression model with the current solver
  2. model = LogisticRegression(class_weight=&quot;balanced&quot;, max_iter=10, solver=solver, random_state=0)
  3. model.fit(X_train, y_train)
  4. # Check if any warning was generated
  5. if any(issubclass(warning.category, ConvergenceWarning) for warning in w):
  6. print(f&quot;Solver &#39;{solver}&#39; did not converge.&quot;)
  1. Here is the output I get:

Solver 'lbfgs' did not converge.
Solver 'newton-cg' did not converge.
Solver 'sag' did not converge.
Solver 'saga' did not converge.

  1. </details>

huangapple
  • 本文由 发表于 2023年3月15日 20:20:23
  • 转载请务必保留本文链接:https://go.coder-hub.com/75744615.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定