scikit-learn中Column Transformer中的全局变量

huangapple go评论54阅读模式
英文:

Global variable in Column Transformer scikit-learn

问题

在这个代码片段中,我构建了一个转换器(transformer),用于插入一列数据,但似乎修改了原始变量。实际上,我无法重复进行fit操作。这是否是预期行为?这会破坏例如grid_search的行为,因为它会尝试为每个网格拟合(insert)列。

from sklearn.base import BaseEstimator, TransformerMixin

class customColumnTransformer(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        return self
    
    def transform(self, X):
        X.insert(0, 'newCol', 1)
        return X

df = pd.DataFrame([[1, 2], [3, 4]])
display(customColumnTransformer().fit_transform(df))
display(df)
# display(customColumnTransformer().fit_transform(df)) 这会生成一个错误 ValueError: 无法插入 newCol,因为它已经存在

我找到的唯一解决方法是使用:

customColumnTransformer().fit_transform(df.copy())
英文:

in this snippet I built a transformer that insert one column, but it seems to modify the original variable. In fact I cannot repeat the fit operation. Is this expected? This breaks the behavior of grid_search for example because it tries to insert the column for every grid fit.

from sklearn.base import BaseEstimator,TransformerMixin
class customColumnTransformer(BaseEstimator,TransformerMixin):
    def fit(self,X,y=None):
        return self
    
    def transform(self,X):
        X.insert(0,'newCol',1)
        return X
    
df = pd.DataFrame([[1,2],[3,4]])
display(customColumnTransformer().fit_transform(df))
display(df)
# display(customColumnTransformer().fit_transform(df))   THIS GENERATE AN ERROR  ValueError: cannot insert newCol, already exists

The only solution I found is to use
customColumnTransformer().fit_transform(df.copy())

答案1

得分: 1

当您在 transform 函数内部使用 insert 方法修改 DataFrame X 时,这种行为是完全符合预期的。它会原地修改 DataFrame。因此,如果在同一个 DataFrame 上连续调用 fit_transform,会导致错误,因为列 'newCol' 已经存在。

为了解决这个问题并避免修改原始 DataFrame,您可以在插入新列之前创建一个它的副本。这样,每次调用 fit_transform 都会在一个单独的副本上操作,不会影响原始 DataFrame。

下面是修改后的代码版本:

from sklearn.base import BaseEstimator, TransformerMixin
import pandas as pd

class CustomColumnTransformer(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        X_copy = X.copy()  # 创建 DataFrame 的副本
        X_copy.insert(0, 'newCol', 1)
        return X_copy

df = pd.DataFrame([[1, 2], [3, 4]])
display(CustomColumnTransformer().fit_transform(df))
display(df)

通过使用 X_copy = X.copy(),您生成了一个新的 DataFrame 对象 X_copy,可以独立修改,而不会影响原始的 df DataFrame。因此,您可以多次重复使用转换器而不会遇到“已存在”错误。

这个修改后的代码允许您在同一个 DataFrame 上多次调用 fit_transform 而不会生成任何错误。

英文:

Certainly! It's completely expected behavior. When you use the insert method to modify the DataFrame X within your transform function, it modifies the DataFrame in-place. Consequently, subsequent calls to fit_transform using the same DataFrame will result in an error since the column 'newCol' already exists.

To circumvent this issue and avoid modifying the original DataFrame, you can create a copy of it before inserting the new column. This way, each call to fit_transform will operate on a separate copy, leaving the original DataFrame intact.

Here's an updated version of your code that incorporates this change:

from sklearn.base import BaseEstimator, TransformerMixin
import pandas as pd

class CustomColumnTransformer(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        return self
    
    def transform(self, X):
        X_copy = X.copy()  # Create a copy of the DataFrame
        X_copy.insert(0, 'newCol', 1)
        return X_copy

df = pd.DataFrame([[1, 2], [3, 4]])
display(CustomColumnTransformer().fit_transform(df))
display(df)

By utilizing X_copy = X.copy(), you generate a new DataFrame object X_copy that can be modified independently without affecting the original df DataFrame. Consequently, you can reuse the transformer multiple times without encountering the "already exists" error.

This revised code allows you to call fit_transform repeatedly on the same DataFrame without generating any errors.

huangapple
  • 本文由 发表于 2023年6月8日 18:24:18
  • 转载请务必保留本文链接:https://go.coder-hub.com/76430896.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定