Shading regions inside an mplfinance chart
Question
I am using matplotlib v 3.7.0, mplfinance version '0.12.9b7', and Python 3.10.
I am trying to shade regions of a plot, and although my logic seems correct, the shaded areas are not being displayed on the plot.
This is my code:
import yfinance as yf
import mplfinance as mpf
import pandas as pd
# Download the stock data
df = yf.download('TSLA', start='2022-01-01', end='2022-03-31')
# Define the date ranges for shading
red_range = ['2022-01-15', '2022-02-15']
blue_range = ['2022-03-01', '2022-03-15']
# Create a function to shade the chart regions
def shade_region(ax, region_dates, color):
    region_dates.sort()
    start_date = region_dates[0]
    end_date = region_dates[1]
    # plot vertical lines
    ax.axvline(pd.to_datetime(start_date), color=color, linestyle='--')
    ax.axvline(pd.to_datetime(end_date), color=color, linestyle='--')
    # create fill
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    ax.fill_between(pd.date_range(start=start_date, end=end_date), ymin, ymax, alpha=0.2, color=color)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
# Plot the candlestick chart with volume
fig, axlist = mpf.plot(df, type='candle', volume=True, style='charles',
                       title='TSLA Stock Price', ylabel='Price ($)', ylabel_lower='Shares\nTraded',
                       figratio=(2,1), figsize=(10,5), tight_layout=True, returnfig=True)
# Get the current axis object
ax = axlist[0]
# Shade the regions on the chart
shade_region(ax, red_range, 'red')
shade_region(ax, blue_range, 'blue')
# Show the plot
mpf.show()
Why are the selected regions not being shaded, and how do I fix this?
Answer 1
Score: 4
The problem is that, when `show_nontrading=False` (which is the default when not specified), the x-axis values are *not* dates as you would expect. Thus the vertical lines and the `fill_between` that you are specifying **by date** actually end up way off the chart.
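To make the mismatch concrete, here is a small sketch comparing the two coordinate systems. It uses a synthetic business-day index as a stand-in for the downloaded TSLA data and relies on matplotlib's own date conversion; the variable names are illustrative assumptions, not part of mplfinance.

```python
from datetime import datetime

import matplotlib.dates as mdates
import pandas as pd

# Synthetic stand-in for the downloaded data: one row per business day
index = pd.bdate_range("2022-01-03", "2022-03-31")

# With show_nontrading=False the candles sit at row positions 0..N-1
last_row_position = len(index) - 1

# A date passed to axvline/fill_between is converted to matplotlib's
# date units (days since its epoch), which is a much larger number
date_as_x = mdates.date2num(datetime(2022, 1, 15))

print(f"x-axis spans roughly 0 to {last_row_position}")
print(f"'2022-01-15' would be drawn at x = {date_as_x}")
```

Because the date lands thousands of units to the right of the last candle, restoring the original x-limits afterwards hides the lines and the fill entirely.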
The simplest solution is to set `show_nontrading=True`. Using your code:
```python
fig, axlist = mpf.plot(df, type='candle', volume=True, style='charles',
                       title='TSLA Stock Price', ylabel='Price ($)',
                       ylabel_lower='Shares\nTraded', figratio=(2,1),
                       figsize=(10,5), tight_layout=True, returnfig=True,
                       show_nontrading=True)

# Get the current axis object
ax = axlist[0]

# Shade the regions on the chart
shade_region(ax, red_range, 'red')
shade_region(ax, blue_range, 'blue')

# Show the plot
mpf.show()
```
There are two other solutions that allow you to leave `show_nontrading=False`, if that is your preference.
1. The first solution is to use mplfinance's own kwargs (`vlines` and `fill_between`) and not use `returnfig`. Both kwargs are described in the mplfinance documentation.
This is the preferred solution, since it is generally a good idea to let mplfinance do all manipulation of Axes objects unless there is something that you cannot accomplish otherwise.
And here is an example modifying your code:
red_range = ['2022-01-15', '2022-02-15']
blue_range = ['2022-03-01', '2022-03-15']

vline_dates = red_range + blue_range
vline_colors = ['red', 'red', 'blue', 'blue']
vline_dict = dict(vlines=vline_dates, colors=vline_colors, linestyle='--')

ymax = max(df['High'].values)
ymin = min(df['Low'].values)

# create a dataframe from the datetime index
# for use in generating the fill_between `where` values:
dfdates = df.index.to_frame()

# generate red boolean `where` values:
where_values = pd.notnull(dfdates[(dfdates >= red_range[0]) & (dfdates <= red_range[1])].Date.values)

# put together the red fill_between specification:
fb_red = dict(y1=ymin, y2=ymax, where=where_values, alpha=0.2, color='red')

# generate blue boolean `where` values:
where_values = pd.notnull(dfdates[(dfdates >= blue_range[0]) & (dfdates <= blue_range[1])].Date.values)

# put together the blue fill_between specification:
fb_blue = dict(y1=ymin, y2=ymax, where=where_values, alpha=0.2, color='blue')

# Plot the candlestick chart with volume
mpf.plot(df, type='candle', volume=True, style='charles',
         title='TSLA Stock Price', ylabel='Price ($)', ylabel_lower='Shares\nTraded',
         figratio=(2,1), figsize=(10,5), tight_layout=True,
         vlines=vline_dict,
         fill_between=[fb_red, fb_blue])
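The boolean `where` values can also be built directly from the index, which may be easier to read. The sketch below is one possible way to do it, not the method used in the answer: `range_mask` is a hypothetical helper, and the business-day index is synthetic data standing in for the downloaded prices.

```python
import pandas as pd

# Synthetic stand-in for the downloaded prices: business days only
index = pd.bdate_range("2022-01-03", "2022-03-31", name="Date")
df = pd.DataFrame({"Close": range(len(index))}, index=index)


def range_mask(dt_index, start, end):
    """Return a boolean array that is True where start <= date <= end."""
    start, end = pd.Timestamp(start), pd.Timestamp(end)
    return (dt_index >= start) & (dt_index <= end)


# One flag per row of df, True for rows inside the red range
red_where = range_mask(df.index, "2022-01-15", "2022-02-15")

print(len(red_where), int(red_where.sum()))
```

The resulting array has one entry per row of the dataframe, which is the shape a `where` value needs in order to line up with the plotted candles.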
Notice that there is a slight gap between the leftmost vertical line and the shaded area. That is because the date you selected ('2022-01-15') falls on a weekend (and a 3-day weekend at that). If you change the date to '2022-01-14' or '2022-01-18', it will work fine.
2. The last solution requires `returnfig=True`. This is not the recommended solution, but it does work.
First, it's important to understand the following: when `show_nontrading` is not specified, it defaults to `False`, which means that, although you see datetimes displayed on the x-axis, the actual values are the row numbers of your dataframe.
Therefore, in your code, instead of specifying dates, you need to specify the row number where each date appears.
The simplest way to get the row number is to use the function `date_to_iloc(df.index.to_series(), date)`, defined here:
def date_to_iloc(dtseries, date):
    '''Convert a `date` to a location, given a date series w/a datetime index.
    If `date` does not exactly match a date in the series then interpolate between two dates.
    If `date` is outside the range of dates in the series, then raise an exception.
    '''
    d1s = dtseries.loc[date:]
    if len(d1s) < 1:
        sdtrange = str(dtseries.iloc[0]) + ' to ' + str(dtseries.iloc[-1])
        raise ValueError('User specified line date "' + str(date) +
                         '" is beyond (greater than) range of plotted data (' + sdtrange + ').')
    d1 = d1s.index[0]
    d2s = dtseries.loc[:date]
    if len(d2s) < 1:
        sdtrange = str(dtseries.iloc[0]) + ' to ' + str(dtseries.iloc[-1])
        raise ValueError('User specified line date "' + str(date) +
                         '" is before (less than) range of plotted data (' + sdtrange + ').')
    d2 = dtseries.loc[:date].index[-1]
    # If there are duplicate dates in the series, for example in a renko plot
    # then .get_loc(date) will return a slice containing all the dups, so:
    loc1 = dtseries.index.get_loc(d1)
    if isinstance(loc1, slice):
        loc1 = loc1.start
    loc2 = dtseries.index.get_loc(d2)
    if isinstance(loc2, slice):
        loc2 = loc2.stop - 1
    return (loc1 + loc2) / 2.0
The function takes as input the data frame index converted to a series. So the following changes to your code will allow it to work using this method:
# Define the date ranges for shading (as row positions, not dates)
red_range = [date_to_iloc(df.index.to_series(), dt) for dt in ['2022-01-15', '2022-02-15']]
blue_range = [date_to_iloc(df.index.to_series(), dt) for dt in ['2022-03-01', '2022-03-15']]
...
    # inside shade_region: pass the row positions directly, without pd.to_datetime
    ax.axvline(start_date, color=color, linestyle='--')
    ax.axvline(end_date, color=color, linestyle='--')
...
    ax.fill_between([start_date, end_date], ymin, ymax, alpha=0.2, color=color)
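To see what these row positions look like, here is a compact sketch of the same idea. It is not the mplfinance function: `date_to_position` is a hypothetical helper that assumes a sorted index with no duplicate dates, and the business-day index is synthetic (unlike real market data, it does not skip exchange holidays).

```python
import pandas as pd

# Synthetic business-day index standing in for df.index
index = pd.bdate_range("2022-01-03", "2022-03-31")


def date_to_position(dt_index, date):
    """Map a date to a (possibly fractional) row position in a sorted index."""
    ts = pd.Timestamp(date)
    if ts < dt_index[0] or ts > dt_index[-1]:
        raise ValueError(f"{date} is outside the range of plotted data")
    right = dt_index.searchsorted(ts, side="left")      # first row with date >= ts
    left = dt_index.searchsorted(ts, side="right") - 1  # last row with date <= ts
    # Exact match: left == right. Otherwise: halfway between the neighbours.
    return float((left + right) / 2.0)


print(date_to_position(index, "2022-01-14"))  # a trading day: 9.0
print(date_to_position(index, "2022-01-15"))  # a Saturday: 9.5
```

A weekend date falls halfway between the two neighbouring rows, which is why a vertical line drawn for '2022-01-15' sits between two candles rather than on one.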