英文:
Custom colored lines in matplotlib but legend won't update
问题
我正尝试从一个Pandas数据框中绘制线条,这个数据框是我的数据的一个透视表。有12条线,每个月一条,我不想有重复的颜色。因此,我决定手动设置颜色。在绘图中,线条的颜色已经更新,但当我创建图例时,它仍然显示旧的颜色。
以下是我的代码片段。我想更改12个子图中的4个子图的颜色:
fig, axs = plt.subplots(4, 3)
for i, country in enumerate(countries):
pivot_yearly = pd.pivot_table(df[country].to_frame(), index=df.index.month, columns=df.index.year)
pivot_monthly = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.month) # 问题出在这里。
pivot_dyyr = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.year)
axs[i,0].plot(pivot_yearly)
cm = plt.get_cmap('gist_rainbow')
cNorm = colors.Normalize(vmin=0, vmax=11)
scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)
axs[i,1].set_prop_cycle(color=[scalarMap.to_rgba(j) for j in range(12)])
axs[i,1].plot(pivot_monthly)
axs[i,3].plot(pivot_dyyr)
labels = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
leg1 = fig.legend(pivot_yearly.columns.get_level_values('Date'), ncols=3)
leg2 = fig.legend(pivot_monthly.columns.get_level_values('Date'), labels=labels, ncols=3)
leg3 = fig.legend(pivot_dyyr.columns.get_level_values('Date'), ncols=3)
plt.tight_layout()
plt.show()
不幸的是,我无法分享我的数据,但它涉及4个国家和索引'Date'作为DateTimeIndex。
当我运行axs[i,1].set_prop_cycle()
这一行时,线条颜色已更新,但图例没有更新。
英文:
I am trying to plot lines out of a pandas dataframe, a pivot table of my data. There's 12 lines, one for each month of the year, and I don't want any duplicate colors. So I decided to set the colors manually. In the plot, the color of the lines is updated, but when I create a legend, it still displays the old colors.
Below is a snippet of my code. I want to change the colors for 4 out of 12 subplots:
fig, axs = plt.subplots(4, 3)
for i, country in enumerate(countries):
pivot_yearly = pd.pivot_table(df[country].to_frame(), index=df.index.month, columns=df.index.year)
pivot_monthly = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.month) # The problem is with this one.
pivot_dyyr = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.year)
axs[i,0].plot(pivot_yearly)
cm = plt.get_cmap('gist_rainbow')
cNorm = colors.Normalize(vmin=0, vmax=11)
scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)
axs[i,1].set_prop_cycle(color=[scalarMap.to_rgba(j) for j in range(12)])
axs[i,1].plot(pivot_monthly)
axs[i,3].plot(pivot_dyyr)
labels = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
leg1 = fig.legend(pivot_yearly.columsn.get_level_values('Date'), ncols=3)
leg2 = fig.legend(pivot_monthly.columns.get_level_values('Date'), labels=labels, ncols=3)
leg3 = fig.legend(pivot_dyyr.columns.get_level_values('Date'), ncols=3)
plt.tight_layout()
plt.show()
Unfortunately I cannot share my data, but it concerns 4 countries and the index 'Date' as a DateTimeIndex.
When I run the line axs[i,1].set_prop_cycle()
the line colors are updated, but the legend is not.
答案1
得分: 1
由于您没有提供任何数据,我使用了healthexp
数据集,因为它包含了6个国家和年份的数据。然后,我将其转换为我认为符合您要求的格式,将每个国家转换为一列并绘制了图表。由于此数据集中仅有6个国家,我将您的4x3图表更改为3x2。请检查是否符合您的要求。
fig, axs = plt.subplots(2, 3, figsize=(12,5))
## 我的数据
df=sns.load_dataset('healthexp')
df['Month'] = np.random.randint(1,13, size=(len(df))) ## 创建随机月份
df['Year']= pd.to_datetime(df['Year'].astype(str)+df['Month'].astype(str), format='%Y%m') # 将年份更改为日期时间格式
df=df.pivot_table(index="Year", columns="Country", values="Spending_USD", fill_value=0) ## 将每个国家更改为列
countries=df.columns ## 猜测您想要每个国家一个图
for i, country in enumerate(countries):
pivot = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.month)
cm = plt.get_cmap('gist_rainbow')
cNorm = colors.Normalize(vmin=0, vmax=12)
scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)
axs[i%2,int(i/2)].set_prop_cycle(color=[scalarMap.to_rgba(j) for j in range(12)]) ## 注意我将axs更改为3x2
axs[i%2,int(i/2)].plot(pivot) ## 注意我将axs更改为3x2
labels = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
leg = fig.legend(pivot.columns.get_level_values('Year'), labels=labels, ncol=3,bbox_to_anchor=(.65,0)) ## ncol因为我使用的是较旧版本的matplotlib
plt.show()
英文:
As you have not provided any data, I took the healthexp
dataset because it has 6 countries and year. Post that, converted it to what I assume is the format you have by changing each country to a column and plotted it. As there are just 6 countries in this dataset, I have changed you 4x3 plot to 3x2. Please check if this works for you.
fig, axs = plt.subplots(2, 3, figsize=(12,5))
## My data
df=sns.load_dataset('healthexp')
df['Month'] = np.random.randint(1,13, size=(len(df))) ## Create random month
df['Year']= pd.to_datetime(df['Year'].astype(str)+df['Month'].astype(str), format='%Y%m') #Change year to datetime
df=df.pivot_table(index="Year", columns="Country", values="Spending_USD", fill_value=0) ## Change each Country as column
countries=df.columns ## Guessing you want one plot per country
for i, country in enumerate(countries):
pivot = pd.pivot_table(df[country].to_frame(), index=df.index.weekday, columns=df.index.month)
cm = plt.get_cmap('gist_rainbow')
cNorm = colors.Normalize(vmin=0, vmax=12)
scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)
axs[i%2,int(i/2)].set_prop_cycle(color=[scalarMap.to_rgba(j) for j in range(12)]) ## Note I changed axs to 3x2
axs[i%2,int(i/2)].plot(pivot) ## Note I changed axs to 3x2
labels = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
leg = fig.legend(pivot.columns.get_level_values('Year'), labels=labels, ncol=3,bbox_to_anchor=(.65,0)) ##ncol as I om on older version of matplotlib
plt.show()
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论