自定义颜色的matplotlib线条,但图例不会更新。

huangapple go评论147阅读模式
英文:

Custom colored lines in matplotlib but legend won't update

问题

我正尝试从一个Pandas数据框中绘制线条,这个数据框是我的数据的一个透视表。有12条线,每个月一条,我不想有重复的颜色。因此,我决定手动设置颜色。在绘图中,线条的颜色已经更新,但当我创建图例时,它仍然显示旧的颜色。

以下是我的代码片段。我想更改12个子图中的4个子图的颜色:

fig, axs = plt.subplots(4, 3)

for i, country in enumerate(countries):
    
    pivot_yearly = pd.pivot_table(df[country].to_frame(), index=df.index.month, columns=df.index.year)
    pivot_monthly = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.month) # 问题出在这里。
    pivot_dyyr = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.year)
    
    axs[i,0].plot(pivot_yearly)

    cm = plt.get_cmap('gist_rainbow')
    cNorm = colors.Normalize(vmin=0, vmax=11)
    scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)

    axs[i,1].set_prop_cycle(color=[scalarMap.to_rgba(j) for j in range(12)])
    axs[i,1].plot(pivot_monthly)

    axs[i,3].plot(pivot_dyyr)

labels = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
leg1 = fig.legend(pivot_yearly.columns.get_level_values('Date'), ncols=3)
leg2 = fig.legend(pivot_monthly.columns.get_level_values('Date'), labels=labels, ncols=3)
leg3 = fig.legend(pivot_dyyr.columns.get_level_values('Date'), ncols=3)

plt.tight_layout()
plt.show()

不幸的是,我无法分享我的数据,但它涉及4个国家和索引'Date'作为DateTimeIndex。

当我运行axs[i,1].set_prop_cycle()这一行时,线条颜色已更新,但图例没有更新。

英文:

I am trying to plot lines out of a pandas dataframe, a pivot table of my data. There's 12 lines, one for each month of the year, and I don't want any duplicate colors. So I decided to set the colors manually. In the plot, the color of the lines is updated, but when I create a legend, it still displays the old colors.

Below is a snippet of my code. I want to change the colors for 4 out of 12 subplots:

fig, axs = plt.subplots(4, 3)

for i, country in enumerate(countries):
    
    pivot_yearly = pd.pivot_table(df[country].to_frame(), index=df.index.month, columns=df.index.year)
    pivot_monthly = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.month) # The problem is with this one.
    pivot_dyyr = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.year)
    
    axs[i,0].plot(pivot_yearly)

    cm = plt.get_cmap('gist_rainbow')
    cNorm = colors.Normalize(vmin=0, vmax=11)
    scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)

    axs[i,1].set_prop_cycle(color=[scalarMap.to_rgba(j) for j in range(12)])
    axs[i,1].plot(pivot_monthly)

    axs[i,3].plot(pivot_dyyr)

labels = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
leg1 = fig.legend(pivot_yearly.columsn.get_level_values('Date'), ncols=3)
leg2 = fig.legend(pivot_monthly.columns.get_level_values('Date'), labels=labels, ncols=3)
leg3 = fig.legend(pivot_dyyr.columns.get_level_values('Date'), ncols=3)

plt.tight_layout()
plt.show()

Unfortunately I cannot share my data, but it concerns 4 countries and the index 'Date' as a DateTimeIndex.

When I run the line axs[i,1].set_prop_cycle() the line colors are updated, but the legend is not.

答案1

得分: 1

由于您没有提供任何数据,我使用了healthexp数据集,因为它包含了6个国家和年份的数据。然后,我将其转换为我认为符合您要求的格式,将每个国家转换为一列并绘制了图表。由于此数据集中仅有6个国家,我将您的4x3图表更改为3x2。请检查是否符合您的要求。

fig, axs = plt.subplots(2, 3, figsize=(12,5))

## 我的数据
df=sns.load_dataset('healthexp')
df['Month'] = np.random.randint(1,13, size=(len(df)))  ## 创建随机月份
df['Year']= pd.to_datetime(df['Year'].astype(str)+df['Month'].astype(str), format='%Y%m') # 将年份更改为日期时间格式
df=df.pivot_table(index="Year", columns="Country", values="Spending_USD", fill_value=0) ## 将每个国家更改为列
countries=df.columns  ## 猜测您想要每个国家一个图

for i, country in enumerate(countries):
    pivot = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.month)
    cm = plt.get_cmap('gist_rainbow')
    cNorm = colors.Normalize(vmin=0, vmax=12)
    scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)
    axs[i%2,int(i/2)].set_prop_cycle(color=[scalarMap.to_rgba(j) for j in range(12)])  ## 注意我将axs更改为3x2
    axs[i%2,int(i/2)].plot(pivot)  ## 注意我将axs更改为3x2

labels = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
leg = fig.legend(pivot.columns.get_level_values('Year'), labels=labels, ncol=3,bbox_to_anchor=(.65,0)) ## ncol因为我使用的是较旧版本的matplotlib

plt.show()

自定义颜色的matplotlib线条,但图例不会更新。

英文:

As you have not provided any data, I took the healthexp dataset because it has 6 countries and year. Post that, converted it to what I assume is the format you have by changing each country to a column and plotted it. As there are just 6 countries in this dataset, I have changed you 4x3 plot to 3x2. Please check if this works for you.

fig, axs = plt.subplots(2, 3, figsize=(12,5))

## My data
df=sns.load_dataset('healthexp')
df['Month'] = np.random.randint(1,13, size=(len(df)))  ## Create random month
df['Year']= pd.to_datetime(df['Year'].astype(str)+df['Month'].astype(str), format='%Y%m') #Change year to datetime
df=df.pivot_table(index="Year", columns="Country", values="Spending_USD", fill_value=0) ## Change each Country as column
countries=df.columns  ## Guessing you want one plot per country

for i, country in enumerate(countries):
    pivot = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.month)
    cm = plt.get_cmap('gist_rainbow')
    cNorm = colors.Normalize(vmin=0, vmax=12)
    scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)
    axs[i%2,int(i/2)].set_prop_cycle(color=[scalarMap.to_rgba(j) for j in range(12)])  ## Note I changed axs to 3x2
    axs[i%2,int(i/2)].plot(pivot)  ## Note I changed axs to 3x2

labels = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
leg = fig.legend(pivot.columns.get_level_values('Year'), labels=labels, ncol=3,bbox_to_anchor=(.65,0)) ##ncol as I om on older version of matplotlib

plt.show()

自定义颜色的matplotlib线条,但图例不会更新。

huangapple
  • 本文由 发表于 2023年6月22日 15:56:59
  • 转载请务必保留本文链接:https://go.coder-hub.com/76529699.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定