如何删除低于某个阈值的值,但保留热图中的颜色?

huangapple go评论67阅读模式
英文:

How to remove values under a certain threshold but keep the coloring in a heat map?

问题

I made a heat map function based on seaborn that plots the heat map as I want it to be plotted. The only problem is that I also want to hide the values from the heat map in case they are less than 5 but I want to keep the coloring. I managed to do it with this

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc

# Fix the seed so the example produces the same random table on every run.
np.random.seed(1234)
# 5x5 table of random values scaled to the range [0, 100).
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
# Divide every row by its row sum and scale by 100, so each row adds up to 100.
data_norm = data.div(data.sum(axis=1), axis=0) * 100

# Annotation table: keep values >= 5 and replace every other value with None.
annot_data = data_norm.applymap(lambda x: x if x >= 5 else None)
# Alternative attempt that substitutes an empty string instead of None:
# annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")

def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    """Draw an annotated seaborn heat map of ``data`` and optionally save it.

    Args:
        data: Table to plot; must be a pandas DataFrame.
        vmin: Lower limit forwarded to ``sns.heatmap``.
        vmax: Upper limit forwarded to ``sns.heatmap``.
        **kwargs: Optional settings, each read with ``kwargs.get``:
            ``color`` (colormap colors, default ``["w", "r"]``), ``figsize``,
            ``annot`` (default ``True``), ``linewidth``, ``linecolor``,
            ``fmt`` (default ``".0f"``), ``ticks`` (colorbar ticks),
            ``xrotation``, ``yrotation``, ``fontsize`` (default 10),
            ``labelsize`` (default 10), ``cb_label``, ``xlabel``, ``ylabel``,
            ``savename``, ``format`` (default ``"svg"``) and ``dpi``.

    Raises:
        ValueError: If ``data`` is not a pandas DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")
    # Build a linear colormap that interpolates between the given colors.
    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    # Default size presumably is 16.56 cm x 8.83 cm converted to inches (/2.54).
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56/2.54, 8.83/2.54)))
    ax = sns.heatmap(data, cmap=cmap, annot=kwargs.get("annot", True), linewidth=kwargs.get("linewidth", None), linecolor=kwargs.get("linecolor", None), fmt=kwargs.get("fmt", ".0f"), vmin=vmin, vmax=vmax, cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])})
    # Inward-pointing ticks on all four sides of the axes.
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    # plt.xticks / plt.yticks act on the current axes.
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    # Style the colorbar attached to the first collection of the axes.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=kwargs.get("fontsize", 10))
    # Make every spine of the current axes visible with a black edge of width 1.
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)
    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=kwargs.get("fontsize", 10))
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=kwargs.get("fontsize", 10))
    # Save only when a base file name was given; the extension comes from "format".
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))

heatmaplotter(data_norm, annot=annot_data)

but the problem is that now `nan` is shown in the cells that are less than 5 (Figure 1).

Figure 1: `nan` are present.

I want them to be empty. I tried with

```python
annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")
```

which alas gives ValueError: Unknown format code 'f' for object of type 'str'. I also tried

# Variant of the sns.heatmap call that masks the plotted data itself where data <= 5.
ax = sns.heatmap(data.mask(data <= 5), cmap=cmap, annot=kwargs.get("annot", True), linewidth=kwargs.get("linewidth", None), linecolor=kwargs.get("linecolor", None), fmt=kwargs.get("fmt", ".0f"), vmin=vmin, vmax=vmax, cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])})

but that removes the colors as well which I want to keep (Figure 2).

Figure 2: No coloring in the cells that are less than 5.

Is there a way to mask the values that are less than 5 but still keep the coloring?

英文:

I made a heat map function based on seaborn that plots the heat map as I want it to be plotted. The only problem is that I also want to hide the values from the heat map in case they are less than 5 but I want to keep the coloring. I managed to do it with this:

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc

# Fix the seed so the example produces the same random table on every run.
np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
# Divide every row by its row sum and scale by 100, so each row adds up to 100.
data_norm = data.div(data.sum(axis=1), axis=0) * 100

# Annotation table: keep values >= 5 and replace every other value with None.
annot_data = data_norm.applymap(lambda x: x if x >= 5 else None)
# Alternative attempt that substitutes an empty string instead of None:
# annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")


def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    """Draw an annotated seaborn heat map of ``data`` and optionally save it.

    Args:
        data: Table to plot; must be a pandas DataFrame.
        vmin: Lower limit forwarded to ``sns.heatmap``.
        vmax: Upper limit forwarded to ``sns.heatmap``.
        **kwargs: Optional settings read with ``kwargs.get`` (``color``,
            ``figsize``, ``annot``, ``linewidth``, ``linecolor``, ``fmt``,
            ``ticks``, ``xrotation``, ``yrotation``, ``fontsize``,
            ``labelsize``, ``cb_label``, ``xlabel``, ``ylabel``, ``savename``,
            ``format``, ``dpi``).

    Raises:
        ValueError: If ``data`` is not a pandas DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")
    # Build a linear colormap that interpolates between the given colors.
    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56/2.54, 8.83/2.54)))
    ax = sns.heatmap(data, cmap=cmap, annot=kwargs.get("annot", True), linewidth=kwargs.get("linewidth", None), linecolor=kwargs.get("linecolor", None), fmt=kwargs.get("fmt", ".0f"), vmin=vmin, vmax=vmax, cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])})
    # Inward-pointing ticks on all four sides of the axes.
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    # Style the colorbar attached to the first collection of the axes.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=kwargs.get("fontsize", 10))
    # Make every spine of the current axes visible with a black edge of width 1.
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)
    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=kwargs.get("fontsize", 10))
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=kwargs.get("fontsize", 10))
    # Save only when a base file name was given; the extension comes from "format".
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))


heatmaplotter(data_norm, annot=annot_data)

but the problem is that now nan is shown in the cells that are less than 5 (Figure 1).

如何删除低于某个阈值的值,但保留热图中的颜色?
Figure 1: nan are present.

I want them to be empty. I tried with

# Replace values below 5 with an empty string instead of None.
annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")

which alas gives ValueError: Unknown format code 'f' for object of type 'str'. I also tried

# Variant of the sns.heatmap call that masks the plotted data itself where data <= 5.
ax = sns.heatmap(data.mask(data <= 5), cmap=cmap, annot=kwargs.get("annot", True), linewidth=kwargs.get("linewidth", None), linecolor=kwargs.get("linecolor", None), fmt=kwargs.get("fmt", ".0f"), vmin=vmin, vmax=vmax, cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])})

but that removes the colors as well which I want to keep (Figure 2).

如何删除低于某个阈值的值,但保留热图中的颜色?
Figure 2: No coloring in the cells that are less than 5.

Is there a way to mask the values that are less than 5 but still keep the coloring?

答案1

得分: 2

你需要添加以下代码:

部分1:

构建一个布尔数组,标记哪些单元格需要显示标注(值大于 5 的为 True):

# Boolean array: True where the normalized value is greater than 5.
show_annot_array = np.array(data_norm > 5)

部分2:

隐藏不需要显示的标注

# Pair each annotation text with its flag and hide the ones flagged False.
for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
    text.set_visible(show_annot)

最终代码

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc

np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
data_norm = data.div(data.sum(axis=1), axis=0) * 100

################ Part 1 ################
# Boolean array: True where the normalized value is greater than 5.
# NOTE(review): this uses ">" while annot_data below uses ">=", so a value of
# exactly 5 is kept in annot_data but flagged as hidden here.
show_annot_array = np.array(data_norm > 5)
########################################

# Annotation table: keep values >= 5 and replace every other value with None.
annot_data = data_norm.applymap(lambda x: x if x >= 5 else None)

def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    """Draw an annotated seaborn heat map and hide the unwanted annotations.

    Cell colors always follow ``data``. The annotation texts are hidden for
    every cell whose entry in the module-level ``show_annot_array`` is False.
    The hiding is done before the figure is saved, so the saved file matches
    what is shown on screen.

    Args:
        data: Table to plot; must be a pandas DataFrame.
        vmin: Lower limit forwarded to ``sns.heatmap``.
        vmax: Upper limit forwarded to ``sns.heatmap``.
        **kwargs: Optional settings read with ``kwargs.get`` (``color``,
            ``figsize``, ``annot``, ``linewidth``, ``linecolor``, ``fmt``,
            ``ticks``, ``xrotation``, ``yrotation``, ``fontsize``,
            ``labelsize``, ``cb_label``, ``xlabel``, ``ylabel``, ``savename``,
            ``format``, ``dpi``).

    Raises:
        ValueError: If ``data`` is not a pandas DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")
    # Build a linear colormap that interpolates between the given colors.
    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56/2.54, 8.83/2.54)))
    ax = sns.heatmap(data, cmap=cmap, annot=kwargs.get("annot", True), linewidth=kwargs.get("linewidth", None), linecolor=kwargs.get("linecolor", None), fmt=kwargs.get("fmt", ".0f"), vmin=vmin, vmax=vmax, cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])})
    # Inward-pointing ticks on all four sides of the axes.
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    # Style the colorbar attached to the first collection of the axes.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=kwargs.get("fontsize", 10))
    # Make every spine of the current axes visible with a black edge of width 1.
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)
    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=kwargs.get("fontsize", 10))
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=kwargs.get("fontsize", 10))

    ################ Part 2 ################
    # Hide the annotation of every cell flagged False. This must run before
    # savefig, otherwise the saved file would still contain the hidden labels.
    # Relies on ax.texts holding one text per cell in the same (row-major)
    # order as show_annot_array.ravel().
    for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
        text.set_visible(show_annot)
    ########################################

    # Save only when a base file name was given; the extension comes from "format".
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))

heatmaplotter(data_norm, annot=annot_data)

希望这有助于你理解代码的翻译。

英文:

You need to add the following code:

Part 1:

Build a boolean array that marks which cells should show an annotation (values greater than 5):

# Boolean array: True where the normalized value is greater than 5.
show_annot_array = np.array(data_norm > 5)

Part 2:

Hide the annotations of the remaining cells:

# Pair each annotation text with its flag and hide the ones flagged False.
for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
    text.set_visible(show_annot)

Final code

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc

# Fix the seed so the example produces the same random table on every run.
np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
# Divide every row by its row sum and scale by 100, so each row adds up to 100.
data_norm = data.div(data.sum(axis=1), axis=0) * 100

################ Part 1 ################
# Boolean array: True where the normalized value is greater than 5.
show_annot_array = np.array(data_norm > 5)
########################################

# Annotation table: keep values >= 5 and replace every other value with None.
annot_data = data_norm.applymap(lambda x: x if x >= 5 else None)
# Alternative attempt that substitutes an empty string instead of None:
# annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")


def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    """Draw an annotated seaborn heat map and hide the unwanted annotations.

    Cell colors always follow ``data``. The annotation texts are hidden for
    every cell whose entry in the module-level ``show_annot_array`` is False.
    The hiding is done before the figure is saved, so the saved file matches
    what is shown on screen.

    Args:
        data: Table to plot; must be a pandas DataFrame.
        vmin: Lower limit forwarded to ``sns.heatmap``.
        vmax: Upper limit forwarded to ``sns.heatmap``.
        **kwargs: Optional settings read with ``kwargs.get`` (``color``,
            ``figsize``, ``annot``, ``linewidth``, ``linecolor``, ``fmt``,
            ``ticks``, ``xrotation``, ``yrotation``, ``fontsize``,
            ``labelsize``, ``cb_label``, ``xlabel``, ``ylabel``, ``savename``,
            ``format``, ``dpi``).

    Raises:
        ValueError: If ``data`` is not a pandas DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")
    # Build a linear colormap that interpolates between the given colors.
    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56/2.54, 8.83/2.54)))
    ax = sns.heatmap(data, cmap=cmap, annot=kwargs.get("annot", True), linewidth=kwargs.get("linewidth", None), linecolor=kwargs.get("linecolor", None), fmt=kwargs.get("fmt", ".0f"), vmin=vmin, vmax=vmax, cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])})
    # Inward-pointing ticks on all four sides of the axes.
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    # Style the colorbar attached to the first collection of the axes.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=kwargs.get("fontsize", 10))
    # Make every spine of the current axes visible with a black edge of width 1.
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)
    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=kwargs.get("fontsize", 10))
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=kwargs.get("fontsize", 10))

    ################ Part 2 ################
    # Hide the annotation of every cell flagged False. This must run before
    # savefig, otherwise the saved file would still contain the hidden labels.
    # Relies on ax.texts holding one text per cell in the same (row-major)
    # order as show_annot_array.ravel().
    for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
        text.set_visible(show_annot)
    ########################################

    # Save only when a base file name was given; the extension comes from "format".
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))


heatmaplotter(data_norm, annot=annot_data)

如何删除低于某个阈值的值,但保留热图中的颜色?

huangapple
  • 本文由 发表于 2023年4月17日 15:34:35
  • 转载请务必保留本文链接:https://go.coder-hub.com/76032686.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定