How to remove values under a certain threshold but keep the coloring in a heat map?

Question

I made a heat map function based on seaborn that plots the heat map the way I want. The only problem is that I also want to hide the annotated values that are less than 5 while keeping the cell colors. I tried to do it with this:

```python
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc

np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
data_norm = data.div(data.sum(axis=1), axis=0) * 100
annot_data = data_norm.applymap(lambda x: x if x >= 5 else None)
# annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")

def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")
    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56/2.54, 8.83/2.54)))
    ax = sns.heatmap(
        data,
        cmap=cmap,
        annot=kwargs.get("annot", True),
        linewidth=kwargs.get("linewidth", None),
        linecolor=kwargs.get("linecolor", None),
        fmt=kwargs.get("fmt", ".0f"),
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])},
    )
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=kwargs.get("fontsize", 10))
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)
    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=kwargs.get("fontsize", 10))
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=kwargs.get("fontsize", 10))
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))

heatmaplotter(data_norm, annot=annot_data)
```

But the problem is that `nan` is now shown in the cells whose value is less than 5 (Figure 1).

Figure 1: `nan` values are present.

I want those cells to be empty. I tried with

```python
annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")
```

which, alas, gives `ValueError: Unknown format code 'f' for object of type 'str'`. I also tried

```python
ax = sns.heatmap(
    data.mask(data <= 5),
    cmap=cmap,
    annot=kwargs.get("annot", True),
    linewidth=kwargs.get("linewidth", None),
    linecolor=kwargs.get("linecolor", None),
    fmt=kwargs.get("fmt", ".0f"),
    vmin=vmin,
    vmax=vmax,
    cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])},
)
```

but that also removes the colors of those cells, which I want to keep (Figure 2).

Figure 2: No coloring in the cells that are less than 5.

Is there a way to mask the values that are less than 5 but still keep the coloring?
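Both failures above appear to have the same root. As far as I can tell from seaborn's source, `heatmap` builds each annotation string as `("{:" + fmt + "}").format(value)`, so with `fmt=".0f"` a NaN (which is what the `None` entries of `annot_data` seem to become) prints as `nan`, and an empty string raises the `ValueError` quoted in the question. Cells masked out of the plotted data are not drawn at all, which would explain the missing colors in Figure 2. The sketch below reproduces the two formatting symptoms with plain Python; `format_annotation` is a hypothetical stand-in, not a seaborn function.

```python
# Hypothetical stand-in that mimics how seaborn appears to format each annotation.
def format_annotation(value, fmt=".0f"):
    return ("{:" + fmt + "}").format(value)


print(format_annotation(12.7))          # 13
print(format_annotation(float("nan")))  # nan  (the None cells of annot_data)
try:
    format_annotation("")               # the "" variant of annot_data
except ValueError as err:
    print(err)  # Unknown format code 'f' for object of type 'str'
```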

Answer 1

Score: 2


You need to add this code:

Part 1:

Build a boolean mask that is `True` for the cells whose annotation should stay visible (values above 5):

```python
show_annot_array = np.array(data_norm > 5)
```
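The `ravel()` used in Part 2 flattens this 2-D mask row by row (C order), which, as far as I can tell from seaborn's source, matches the order in which `heatmap` adds one text per cell to `ax.texts`: left to right, top to bottom. The small frame below is made-up data standing in for `data_norm`, just to show the flattened order.

```python
import numpy as np
import pandas as pd

# Made-up 2x3 stand-in for data_norm.
df = pd.DataFrame([[1.0, 7.0, 3.0], [6.0, 2.0, 9.0]], columns=["A", "B", "C"])
show_annot_array = np.array(df > 5)

# Row-major flattening: row 0 first, then row 1.
print(show_annot_array.ravel().tolist())  # [False, True, False, True, False, True]
```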

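Part 2:

Inside `heatmaplotter`, after the `sns.heatmap(...)` call (and before `fig.savefig`, so the saved file matches what is shown), hide every annotation whose mask entry is `False`: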

```python
for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
    text.set_visible(show_annot)
```
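One caveat with the `zip` pairing: if `ax.texts` and the mask ever differ in length (seaborn does not seem to annotate cells that are masked or NaN in the plotted data, and the axes could hold other text), `zip` silently stops at the shorter sequence and the labels end up paired with the wrong mask entries. Below is a minimal sketch of a stricter variant; `hide_annotations` is a hypothetical helper, and the demo uses plain matplotlib text instead of seaborn so it runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def hide_annotations(ax, show_mask):
    """Hypothetical helper: keep ax.texts[i] visible only where the flattened mask is True.

    Assumes one text per cell in row-major order, as seaborn's heatmap
    appears to produce when no cells are masked.
    """
    texts = list(ax.texts)
    flat = np.asarray(show_mask, dtype=bool).ravel()
    if len(texts) != flat.size:
        # Fail loudly instead of letting zip() misalign labels and mask entries.
        raise ValueError(f"{len(texts)} texts but {flat.size} mask entries")
    for text, show in zip(texts, flat):
        text.set_visible(bool(show))


# Demo: one text per cell of a 2x2 grid, added row by row.
fig, ax = plt.subplots()
for row in range(2):
    for col in range(2):
        ax.text(col + 0.5, row + 0.5, f"{row}{col}")
hide_annotations(ax, [[True, False], [False, True]])
print([t.get_visible() for t in ax.texts])  # [True, False, False, True]
```

Inside `heatmaplotter` this would replace the Part 2 loop with `hide_annotations(ax, show_annot_array)`.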

Final code

```python
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mplc

np.random.seed(1234)
data = pd.DataFrame(np.random.rand(5, 5) * 100, columns=["A", "B", "C", "D", "E"])
data_norm = data.div(data.sum(axis=1), axis=0) * 100
################ Part 1 ################
show_annot_array = np.array(data_norm > 5)
########################################
annot_data = data_norm.applymap(lambda x: x if x >= 5 else None)
# annot_data = data_norm.applymap(lambda x: x if x >= 5 else "")

def heatmaplotter(data: pd.DataFrame, vmin: int = 0, vmax: int = 100, **kwargs) -> None:
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data argument must be a pandas DataFrame")
    color = kwargs.get("color", ["w", "r"])
    cmap = mplc.LinearSegmentedColormap.from_list("", color)
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (16.56/2.54, 8.83/2.54)))
    ax = sns.heatmap(
        data,
        cmap=cmap,
        annot=kwargs.get("annot", True),
        linewidth=kwargs.get("linewidth", None),
        linecolor=kwargs.get("linecolor", None),
        fmt=kwargs.get("fmt", ".0f"),
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"ticks": kwargs.get("ticks", [0, 25, 50, 75, 100])},
    )
    ax.tick_params(axis="both", which="both", direction="in", top=True, bottom=True, left=True, right=True)
    plt.xticks(rotation=kwargs.get("xrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    plt.yticks(rotation=kwargs.get("yrotation", None), ha="right", rotation_mode="anchor", fontsize=kwargs.get("fontsize", 10))
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=kwargs.get("labelsize", 10), direction="in", length=6, width=1, color="k", pad=5)
    cbar.outline.set_edgecolor("k")
    cbar.outline.set_linewidth(1)
    cbar.set_label(kwargs.get("cb_label", None), fontsize=kwargs.get("fontsize", 10))
    for spine in plt.gca().spines.values():
        spine.set_visible(True)
        spine.set_edgecolor("k")
        spine.set_linewidth(1)
    ax.set_xlabel(kwargs.get("xlabel", None), fontsize=kwargs.get("fontsize", 10))
    ax.set_ylabel(kwargs.get("ylabel", None), fontsize=kwargs.get("fontsize", 10))
    ################ Part 2 ################
    # Hide the small-value annotations before saving, so the saved file omits them too
    for text, show_annot in zip(ax.texts, show_annot_array.ravel()):
        text.set_visible(show_annot)
    ########################################
    savename = kwargs.get("savename", None)
    if savename is not None:
        fig.savefig(f"{savename}.{kwargs.get('format', 'svg')}", bbox_inches="tight", dpi=kwargs.get("dpi", None))

heatmaplotter(data_norm, annot=annot_data)
```
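A different route, not taken in the answer above: instead of hiding texts after drawing, pass seaborn an array of ready-made strings together with `fmt=""`, using an empty string for the cells below the threshold. Since `fmt` is applied to each annotation, an empty format spec leaves the strings unchanged, which avoids the `ValueError` from the question. `threshold_labels` is a hypothetical helper; the threshold of 5 and the `.0f` format are taken from the question.

```python
import numpy as np
import pandas as pd


def threshold_labels(values, threshold=5.0, fmt=".0f"):
    """Hypothetical helper: format each cell with `fmt`, or use "" below `threshold`."""
    arr = np.asarray(values, dtype=float)
    # NaN compares False against the threshold, so NaN cells also become "".
    labels = [[format(v, fmt) if v >= threshold else "" for v in row] for row in arr]
    return np.array(labels, dtype=object)


df = pd.DataFrame([[2.4, 61.0], [37.4, 4.9]], columns=["A", "B"])
print(threshold_labels(df).tolist())  # [['', '61'], ['37', '']]
```

With the question's function, which reads `fmt` from `**kwargs`, the call would be `heatmaplotter(data_norm, annot=threshold_labels(data_norm), fmt="")`. The colors still come from `data_norm` itself, so every cell keeps its fill.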

huangapple
  • Posted on April 17, 2023 at 15:34:35
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76032686.html