英文:
How to remove values under a certain threshold but keep the coloring in a heat map?
问题
I made a heat map function based on seaborn
that plots the heat map as I want it to be plotted. The only problem is that I also want to hide the values from the heat map in case they are less than 5 but I want to keep the coloring. I managed to do it with this
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
# Fix the seed so the example matrix is reproducible.
np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
# Rescale every row so that it sums to 100.
data_norm = data.div(data.sum(axis=1), axis=0) * 100
# Annotation values: keep entries >= 5 and replace the rest with NaN.
# ``where`` is the vectorised form of the element-wise ``applymap`` call,
# which is deprecated in recent pandas releases.
annot_data = data_norm.where(data_norm >= 5)
# Rejected alternative: using "" instead of NaN cannot be rendered with a
# float format such as ".0f" (raises ValueError).
# annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")
def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    """Plot *data* as an annotated seaborn heat map and optionally save it.

    Args:
        data: Values to draw; must be a pandas DataFrame.
        vmin: Lower limit of the colour scale.
        vmax: Upper limit of the colour scale.
        **kwargs: Optional settings read with ``kwargs.get``:
            ``color`` (colour list for the colormap, default ``["w", "r"]``),
            ``figsize``, ``annot`` (default ``True``), ``linewidth``,
            ``linecolor``, ``fmt`` (default ``".0f"``), ``ticks`` (colour-bar
            ticks), ``xrotation``, ``yrotation``, ``fontsize`` (default 10),
            ``labelsize`` (default 10), ``cb_label``, ``xlabel``, ``ylabel``,
            ``savename``, ``format`` (default ``"svg"``) and ``dpi``.

    Raises:
        ValueError: If *data* is not a pandas DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")

    # One font size is shared by tick labels, axis labels and the colour-bar label.
    fontsize = kwargs.get("fontsize", 10)

    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    # Default size: presumably centimetres converted to inches (division by 2.54).
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56 / 2.54, 8.83 / 2.54)))
    ax = sns.heatmap(
        data,
        cmap=cmap,
        annot=kwargs.get("annot", True),
        linewidth=kwargs.get("linewidth", None),
        linecolor=kwargs.get("linecolor", None),
        fmt=kwargs.get("fmt", ".0f"),
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])},
    )

    # Inward ticks on all four sides of the heat map.
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=fontsize)
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=fontsize)

    # Style the colour bar attached to the heat-map mesh.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=fontsize)

    # Draw a black frame around the current axes.
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)

    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=fontsize)
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=fontsize)

    # Only write a file when the caller supplied a base name.
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))
# Plot the row-normalised data, passing annot_data as the cell annotations.
heatmaplotter(data_norm, annot=annot_data)
but the problem is that now `nan` is shown in the cells that are less than 5 (Figure 1).
Figure 1: `nan` are present.
I want them to be empty. I tried with
```python
annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")
```
which alas gives `ValueError: Unknown format code 'f' for object of type 'str'`. I also tried
ax = sns.heatmap(data.mask(data <= 5), cmap=cmap, annot=kwargs.get("annot", True), linewidth=kwargs.get("linewidth", None), linecolor=kwargs.get("linecolor", None), fmt=kwargs.get("fmt", ".0f"), vmin=vmin, vmax=vmax, cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])})
but that removes the colors as well which I want to keep (Figure 2).
Figure 2: No coloring in the cells that are less than 5.
Is there a way to mask the values that are less than 5 but still keep the coloring?
英文:
I made a heat map function based on seaborn
that plots the heat map as I want it to be plotted. The only problem is that I also want to hide the values from the heat map in case they are less than 5 but I want to keep the coloring. I managed to do it with this
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
# Fix the seed so the example matrix is reproducible.
np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
# Rescale every row so that it sums to 100.
data_norm = data.div(data.sum(axis=1), axis=0) * 100
# Annotation values: keep entries >= 5 and replace the rest with NaN.
# ``where`` is the vectorised form of the element-wise ``applymap`` call,
# which is deprecated in recent pandas releases.
annot_data = data_norm.where(data_norm >= 5)
# Rejected alternative: using "" instead of NaN cannot be rendered with a
# float format such as ".0f" (raises ValueError).
# annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")
def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    """Plot *data* as an annotated seaborn heat map and optionally save it.

    Args:
        data: Values to draw; must be a pandas DataFrame.
        vmin: Lower limit of the colour scale.
        vmax: Upper limit of the colour scale.
        **kwargs: Optional settings read with ``kwargs.get``:
            ``color`` (colour list for the colormap, default ``["w", "r"]``),
            ``figsize``, ``annot`` (default ``True``), ``linewidth``,
            ``linecolor``, ``fmt`` (default ``".0f"``), ``ticks`` (colour-bar
            ticks), ``xrotation``, ``yrotation``, ``fontsize`` (default 10),
            ``labelsize`` (default 10), ``cb_label``, ``xlabel``, ``ylabel``,
            ``savename``, ``format`` (default ``"svg"``) and ``dpi``.

    Raises:
        ValueError: If *data* is not a pandas DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")

    # One font size is shared by tick labels, axis labels and the colour-bar label.
    fontsize = kwargs.get("fontsize", 10)

    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    # Default size: presumably centimetres converted to inches (division by 2.54).
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56 / 2.54, 8.83 / 2.54)))
    ax = sns.heatmap(
        data,
        cmap=cmap,
        annot=kwargs.get("annot", True),
        linewidth=kwargs.get("linewidth", None),
        linecolor=kwargs.get("linecolor", None),
        fmt=kwargs.get("fmt", ".0f"),
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])},
    )

    # Inward ticks on all four sides of the heat map.
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=fontsize)
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=fontsize)

    # Style the colour bar attached to the heat-map mesh.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=fontsize)

    # Draw a black frame around the current axes.
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)

    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=fontsize)
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=fontsize)

    # Only write a file when the caller supplied a base name.
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))
# Plot the row-normalised data, passing annot_data as the cell annotations.
heatmaplotter(data_norm, annot=annot_data)
but the problem is that now `nan` is shown in the cells that are less than 5 (Figure 1).
I want them to be empty. I tried with
annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")
which alas gives `ValueError: Unknown format code 'f' for object of type 'str'`. I also tried
ax = sns.heatmap(data.mask(data <= 5), cmap=cmap, annot=kwargs.get("annot", True), linewidth=kwargs.get("linewidth", None), linecolor=kwargs.get("linecolor", None), fmt=kwargs.get("fmt", ".0f"), vmin=vmin, vmax=vmax, cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])})
but that removes the colors as well which I want to keep (Figure 2).
Figure 2: No coloring in the cells that are less than 5.
Is there a way to mask the values that are less than 5 but still keep the coloring?
答案1
得分: 2
你需要添加以下代码:
部分1:
生成一个布尔掩码,标记需要显示标注的单元格
# True where the label stays visible; ">=" matches the ">= 5" rule used to build annot_data.
show_annot_array = np.array(data_norm >= 5)
部分2:
隐藏注释
# Assumes ax.texts holds one label per cell in the same order as the
# flattened mask -- TODO confirm against the seaborn version in use.
for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
    text.set_visible(show_annot)
最终代码
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
# Fix the seed so the example matrix is reproducible.
np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
# Rescale every row so that it sums to 100.
data_norm = data.div(data.sum(axis=1), axis=0) * 100
################ Part 1 ################
# Boolean mask of the cells whose label stays visible. ">=" keeps it in
# step with the ">= 5" rule used for annot_data below.
show_annot_array = np.array(data_norm >= 5)
########################################
# Annotation values: keep entries >= 5 and replace the rest with NaN.
# ``where`` is the vectorised form of the element-wise ``applymap`` call,
# which is deprecated in recent pandas releases.
annot_data = data_norm.where(data_norm >= 5)
def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    """Plot *data* as a seaborn heat map, hiding the labels of masked cells.

    Label visibility is taken from the module-level boolean array
    ``show_annot_array``, which must be defined before this function is
    called.

    Args:
        data: Values to draw; must be a pandas DataFrame.
        vmin: Lower limit of the colour scale.
        vmax: Upper limit of the colour scale.
        **kwargs: Optional settings read with ``kwargs.get``:
            ``color`` (colour list for the colormap, default ``["w", "r"]``),
            ``figsize``, ``annot`` (default ``True``), ``linewidth``,
            ``linecolor``, ``fmt`` (default ``".0f"``), ``ticks`` (colour-bar
            ticks), ``xrotation``, ``yrotation``, ``fontsize`` (default 10),
            ``labelsize`` (default 10), ``cb_label``, ``xlabel``, ``ylabel``,
            ``savename``, ``format`` (default ``"svg"``) and ``dpi``.

    Raises:
        ValueError: If *data* is not a pandas DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")

    # One font size is shared by tick labels, axis labels and the colour-bar label.
    fontsize = kwargs.get("fontsize", 10)

    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    # Default size: presumably centimetres converted to inches (division by 2.54).
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56 / 2.54, 8.83 / 2.54)))
    ax = sns.heatmap(
        data,
        cmap=cmap,
        annot=kwargs.get("annot", True),
        linewidth=kwargs.get("linewidth", None),
        linecolor=kwargs.get("linecolor", None),
        fmt=kwargs.get("fmt", ".0f"),
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])},
    )

    # Inward ticks on all four sides of the heat map.
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=fontsize)
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=fontsize)

    # Style the colour bar attached to the heat-map mesh.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=fontsize)

    # Draw a black frame around the current axes.
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)

    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=fontsize)
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=fontsize)

    ################ Part 2 ################
    # Hide the labels of masked cells. This has to happen before the figure
    # is saved, otherwise the exported file would still contain them.
    # Assumes ax.texts holds one label per cell in the same order as the
    # flattened mask -- TODO confirm against the seaborn version in use.
    for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
        text.set_visible(show_annot)
    ########################################

    # Only write a file when the caller supplied a base name.
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))
# Plot the row-normalised data, passing annot_data as the cell annotations.
heatmaplotter(data_norm, annot=annot_data)
希望这有助于你理解代码的翻译。
英文:
You need to add the following code:
Part 1:
Build a boolean mask that marks the cells whose annotation should stay visible
# True where the label stays visible; ">=" matches the ">= 5" rule used to build annot_data.
show_annot_array = np.array(data_norm >= 5)
Part 2:
Hide the annotations of the cells that fall below the threshold
# Assumes ax.texts holds one label per cell in the same order as the
# flattened mask -- TODO confirm against the seaborn version in use.
for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
    text.set_visible(show_annot)
Final code
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
# Fix the seed so the example matrix is reproducible.
np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
# Rescale every row so that it sums to 100.
data_norm = data.div(data.sum(axis=1), axis=0) * 100
################ Part 1 ################
# Boolean mask of the cells whose label stays visible. ">=" keeps it in
# step with the ">= 5" rule used for annot_data below.
show_annot_array = np.array(data_norm >= 5)
########################################
# Annotation values: keep entries >= 5 and replace the rest with NaN.
# ``where`` is the vectorised form of the element-wise ``applymap`` call,
# which is deprecated in recent pandas releases.
annot_data = data_norm.where(data_norm >= 5)
# Rejected alternative: using "" instead of NaN cannot be rendered with a
# float format such as ".0f" (raises ValueError).
# annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")
def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    """Plot *data* as a seaborn heat map, hiding the labels of masked cells.

    Label visibility is taken from the module-level boolean array
    ``show_annot_array``, which must be defined before this function is
    called.

    Args:
        data: Values to draw; must be a pandas DataFrame.
        vmin: Lower limit of the colour scale.
        vmax: Upper limit of the colour scale.
        **kwargs: Optional settings read with ``kwargs.get``:
            ``color`` (colour list for the colormap, default ``["w", "r"]``),
            ``figsize``, ``annot`` (default ``True``), ``linewidth``,
            ``linecolor``, ``fmt`` (default ``".0f"``), ``ticks`` (colour-bar
            ticks), ``xrotation``, ``yrotation``, ``fontsize`` (default 10),
            ``labelsize`` (default 10), ``cb_label``, ``xlabel``, ``ylabel``,
            ``savename``, ``format`` (default ``"svg"``) and ``dpi``.

    Raises:
        ValueError: If *data* is not a pandas DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")

    # One font size is shared by tick labels, axis labels and the colour-bar label.
    fontsize = kwargs.get("fontsize", 10)

    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    # Default size: presumably centimetres converted to inches (division by 2.54).
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56 / 2.54, 8.83 / 2.54)))
    ax = sns.heatmap(
        data,
        cmap=cmap,
        annot=kwargs.get("annot", True),
        linewidth=kwargs.get("linewidth", None),
        linecolor=kwargs.get("linecolor", None),
        fmt=kwargs.get("fmt", ".0f"),
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])},
    )

    # Inward ticks on all four sides of the heat map.
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=fontsize)
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=fontsize)

    # Style the colour bar attached to the heat-map mesh.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=fontsize)

    # Draw a black frame around the current axes.
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)

    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=fontsize)
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=fontsize)

    ################ Part 2 ################
    # Hide the labels of masked cells. This has to happen before the figure
    # is saved, otherwise the exported file would still contain them.
    # Assumes ax.texts holds one label per cell in the same order as the
    # flattened mask -- TODO confirm against the seaborn version in use.
    for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
        text.set_visible(show_annot)
    ########################################

    # Only write a file when the caller supplied a base name.
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))
# Plot the row-normalised data, passing annot_data as the cell annotations.
heatmaplotter(data_norm, annot=annot_data)
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论