Coverage for moptipy / evaluation / styler.py: 78%

134 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-24 08:49 +0000

1"""Styler allows to discover groups of data and associate styles with them.""" 

2from typing import Any, Callable, Final, Iterable, cast 

3 

4from matplotlib.artist import Artist # type: ignore 

5from matplotlib.lines import Line2D # type: ignore 

6from pycommons.types import type_error 

7 

8from moptipy.utils.plot_defaults import create_line_style 

9 

10 

class Styler:
    """
    A class for determining groups of elements and styling them.

    A styler maps each object to a key via its key function and assigns
    every distinct key a name and an index. The methods must be used in
    this order, because each step reads state that the previous one
    creates:

    1. :meth:`add` every object (fills the internal key collection),
    2. :meth:`finalize` exactly once (computes `keys`, `names`, `count`,
       `has_none`, and the key-to-index map; discards the collection),
    3. optionally any of the `set_line_*` methods,
    4. :meth:`add_line_style` and/or :meth:`add_to_legend`.
    """

    #: The tuple with the names becomes valid after compilation.
    names: tuple[str, ...]
    #: The tuple with the keys becomes valid after compilation.
    keys: tuple[Any, ...]
    #: The dictionary mapping keys to indices; only valid after compilation.
    __indexes: dict[Any, int]
    #: Is there a None key? Valid after compilation.
    has_none: bool
    #: The number of registered keys (including the `None` key, if any);
    #: valid after compilation.
    count: int
    #: Does this styler have any style associated with it?
    has_style: bool

    def __init__(self,
                 key_func: Callable = lambda x: x,
                 namer: Callable[[Any], str] = str,
                 none_name: str = "None",
                 priority: int | float = 0,
                 name_sort_function: Callable[[str], Any] | None =
                 lambda s: s):
        """
        Initialize the style grouper.

        :param key_func: the key function, obtaining keys from objects
        :param namer: the name function, turning keys into names
        :param none_name: the name for the none-key
        :param priority: the base priority of this grouper
        :param name_sort_function: the function for sorting names, or `None`
            if no name-based sorting shall be performed (in which case the
            keys themselves are sorted)
        :raises TypeError: if any of the parameters has the wrong type
        """
        if not callable(key_func):
            raise type_error(key_func, "key function", call=True)
        if not callable(namer):
            raise type_error(namer, "namer function", call=True)
        if not isinstance(none_name, str):
            raise type_error(none_name, "none_name", str)
        if not isinstance(priority, float | int):
            raise type_error(priority, "priority", (int, float))
        if (name_sort_function is not None) \
                and (not callable(name_sort_function)):
            raise type_error(name_sort_function, "name_sort_function",
                             type(None), call=True)

        def __namer(key,
                    __namer: Callable[[Any], str] = namer,
                    __none_name: str = none_name) -> str:
            # The `None` key gets the fixed `none_name`; every other key
            # is passed to the user-supplied `namer`. The result must be a
            # string that is non-empty after stripping.
            rv = __none_name if key is None else __namer(key)
            if not isinstance(rv, str):
                raise type_error(rv, f"name for key {key!r}", str)
            rv = rv.strip()
            if len(rv) <= 0:
                raise ValueError(
                    "name cannot be empty or just consist of white space")
            return rv

        #: the name sort function
        self.__name_sort_function: Final[Callable[[str], Any] | None] = \
            name_sort_function
        #: The key function of the grouper
        self.key_func: Final[Callable] = key_func
        #: The name function of the grouper
        self.name_func: Final[Callable[[Any], str]] = \
            cast("Callable[[Any], str]", __namer)
        #: The base priority of this grouper
        self.priority: float = float(priority)
        #: The internal collection.
        self.__collection: set = set()
        #: the line colors
        self.__line_colors: tuple | None = None
        #: the line dashes
        self.__line_dashes: tuple | None = None
        #: the line widths
        self.__line_widths: tuple[float, ...] | None = None
        #: the optional line alpha
        self.__line_alphas: tuple[float, ...] | None = None

    def add(self, obj) -> None:
        """
        Add an object to the style collection.

        Only the object's key (as computed by `key_func`) is stored. This
        must happen before :meth:`finalize`.

        :param obj: the object
        """
        self.__collection.add(self.key_func(obj))

    def finalize(self) -> None:
        """
        Compile the styler collection.

        The non-`None` keys are sorted, by applying the name sort function
        to their names if one was given, otherwise by the keys themselves.
        The `None` key, if present, is always placed first. This method
        then builds the key-to-index map, sets `count`, adds `count` to
        `priority`, and resets `has_style`. The key collection is deleted
        afterwards, so neither :meth:`add` nor :meth:`finalize` may be
        called again.
        """
        self.has_none = (None in self.__collection)
        if self.has_none:
            # `None` does not take part in sorting; it is prepended below.
            self.__collection.remove(None)

        name_func: Final[Callable[[Any], str]] = self.name_func
        nsf: Final[Callable[[str], Any] | None] = self.__name_sort_function
        # one (key, name) pair per non-None key
        pairs: list[tuple[Any, str]] = [
            (k, name_func(k)) for k in self.__collection]
        if nsf is None:
            # Keys are unique (they came from a set), so this is a pure
            # key-based sort and names are never compared.
            pairs.sort()
        else:
            sort_func: Final[Callable[[str], Any]] = nsf
            pairs.sort(key=lambda p: sort_func(p[1]))
        if self.has_none:
            pairs.insert(0, (None, name_func(None)))

        self.keys = tuple(p[0] for p in pairs)
        self.names = tuple(p[1] for p in pairs)
        del self.__collection
        del pairs
        self.__indexes = {k: i for i, k in enumerate(self.keys)}
        self.count = len(self.names)
        self.priority += self.count
        self.has_style = False

    def __lt__(self, other) -> bool:
        """
        Check whether this styler is more important than another one.

        A higher priority wins. If the priorities are equal, the styler
        with more non-`None` keys wins.

        :param other: the other styler
        :return: `True` if it is, `False` if it is not.
        """
        # Two explicit comparisons (rather than `!=`) so that
        # incomparable priorities such as NaN fall through to the count
        # comparison.
        if self.priority > other.priority:
            return True
        if self.priority < other.priority:
            return False
        c1 = self.count - (1 if self.has_none else 0)
        c2 = other.count - (1 if other.has_none else 0)
        return c1 > c2

    def __make_style_tuple(self, func: Callable, singular: str,
                           plural: str) -> tuple:
        """
        Call a style-producing function and validate its result.

        :param func: the function, invoked with the number of keys
        :param singular: the singular name of the style, for error messages
        :param plural: the plural name of the style, for error messages
        :return: the tuple of style values, exactly one per key
        :raises TypeError: if `func` is not callable or does not return an
            iterable
        :raises ValueError: if the number of values differs from `count`
        """
        if not callable(func):
            raise type_error(func, f"{singular} func", call=True)
        tmp = func(self.count)
        if not isinstance(tmp, Iterable):
            raise type_error(tmp, f"result of {singular} func", Iterable)
        values: Final[tuple] = tuple(tmp)
        if len(values) != self.count:
            raise ValueError(f"There must be {self.count} {plural}, "
                             f"but found {len(values)}.")
        return values

    def set_line_color(self, line_color_func: Callable) -> None:
        """
        Set that this styler should apply line colors.

        The styler is only modified if the palette is valid.

        :param line_color_func: a function returning the palette when
            called with the number of keys
        :raises TypeError: if the function or its result has the wrong type
        :raises ValueError: if the palette does not have `count` entries
        """
        self.__line_colors = self.__make_style_tuple(
            line_color_func, "line color", "line colors")
        self.has_style = True

    def set_line_dash(self, line_dash_func: Callable) -> None:
        """
        Set that this styler should apply line dashes.

        The styler is only modified if the dashes are valid.

        :param line_dash_func: a function returning the dashes when called
            with the number of keys
        :raises TypeError: if the function or its result has the wrong type
        :raises ValueError: if there are not exactly `count` dashes
        """
        self.__line_dashes = self.__make_style_tuple(
            line_dash_func, "line dash", "line dashes")
        self.has_style = True

    def set_line_width(self, line_width_func: Callable) -> None:
        """
        Set that this styler should apply a line width.

        The styler is only modified if the widths are valid.

        :param line_width_func: the line width function, called with the
            number of keys
        :raises TypeError: if the function or its result has the wrong type
        :raises ValueError: if there are not exactly `count` widths
        """
        self.__line_widths = self.__make_style_tuple(
            line_width_func, "line width", "line widths")
        self.has_style = True

    def set_line_alpha(self, line_alpha_func: Callable) -> None:
        """
        Set that this styler should apply a line alpha.

        The styler is only modified if the alphas are valid.

        :param line_alpha_func: the line alpha function, called with the
            number of keys
        :raises TypeError: if the function or its result has the wrong type
        :raises ValueError: if there are not exactly `count` alphas
        """
        self.__line_alphas = self.__make_style_tuple(
            line_alpha_func, "line alpha", "line alphas")
        self.has_style = True

    def add_line_style(self, obj,
                       style: dict[str, object]) -> None:
        """
        Apply this styler's contents based on the given object.

        Objects whose key was not registered before :meth:`finalize` are
        ignored, and `style` is left unchanged.

        :param obj: the object for which the style should be created
        :param style: the dictionary to which the styles should be added
        """
        # `get` instead of `setdefault`: a lookup must not insert unknown
        # keys into the index map.
        index = self.__indexes.get(self.key_func(obj), -1)
        if index >= 0:
            self.__add_line_style(index, style)

    def __add_line_style(self, index,
                         style: dict[str, object]) -> None:
        """
        Apply this styler's style values for the given key index.

        Only the style aspects that were configured via the `set_line_*`
        methods are written; other entries of `style` are left untouched.

        :param index: the index to be processed
        :param style: the dictionary to which the styles should be added
        """
        if self.__line_colors is not None:
            style["color"] = self.__line_colors[index]
        if self.__line_dashes is not None:
            style["linestyle"] = self.__line_dashes[index]
        if self.__line_widths is not None:
            style["linewidth"] = self.__line_widths[index]
        if self.__line_alphas is not None:
            style["alpha"] = self.__line_alphas[index]

    def add_to_legend(self, consumer: Callable[[Artist], Any]) -> None:
        """
        Add this styler to the legend.

        For every key, in index order, an empty `Line2D` carrying that
        key's style and name (as label) is passed to `consumer`.

        :param consumer: the consumer to add to
        :raises TypeError: if `consumer` is not callable
        """
        if not callable(consumer):
            raise type_error(consumer, "consumer", call=True)
        for i, name in enumerate(self.names):
            style = create_line_style()
            self.__add_line_style(i, style)
            style["label"] = name
            style["xdata"] = []
            style["ydata"] = []
            consumer(Line2D(**style))  # type: ignore