Coverage for moptipy / evaluation / styler.py: 78%
134 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-24 08:49 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-24 08:49 +0000
1"""Styler allows to discover groups of data and associate styles with them."""
2from typing import Any, Callable, Final, Iterable, cast
4from matplotlib.artist import Artist # type: ignore
5from matplotlib.lines import Line2D # type: ignore
6from pycommons.types import type_error
8from moptipy.utils.plot_defaults import create_line_style
11class Styler:
12 """A class for determining groups of elements and styling them."""
14 #: The tuple with the names becomes valid after compilation.
15 names: tuple[str, ...]
16 #: The tuple with the keys becomes valid after compilation.
17 keys: tuple[Any, ...]
18 #: The dictionary mapping keys to indices; only valid after compilation.
19 __indexes: dict[Any, int]
20 #: Is there a None key? Valid after compilation.
21 has_none: bool
22 #: The number of registered keys.
23 count: int
24 #: Does this styler have any style associated with it?
25 has_style: bool
27 def __init__(self,
28 key_func: Callable = lambda x: x,
29 namer: Callable[[Any], str] = str,
30 none_name: str = "None",
31 priority: int | float = 0,
32 name_sort_function: Callable[[str], Any] | None =
33 lambda s: s):
34 """
35 Initialize the style grouper.
37 :param key_func: the key function, obtaining keys from objects
38 :param namer: the name function, turning keys into names
39 :param none_name: the name for the none-key
40 :param priority: the base priority of this grouper
41 :param name_sort_function: the function for sorting names, or `None`
42 if no name-based sorting shall be performed
43 """
44 if not callable(key_func):
45 raise type_error(key_func, "key function", call=True)
47 if not callable(namer):
48 raise type_error(namer, "namer function", call=True)
49 if not isinstance(none_name, str):
50 raise type_error(none_name, "none_name", str)
51 if not isinstance(priority, float | int):
52 raise type_error(priority, "priority", (int, float))
53 if (name_sort_function is not None) \
54 and (not callable(name_sort_function)):
55 raise type_error(name_sort_function, "name_sort_function",
56 type(None), call=True)
58 def __namer(key,
59 __namer: Callable[[Any], str] = namer,
60 __none_name: str = none_name) -> str:
61 rv = __none_name if key is None else __namer(key)
62 if not isinstance(rv, str):
63 raise type_error(rv, f"name for key {key!r}", str)
64 rv = rv.strip()
65 if len(rv) <= 0:
66 raise ValueError(
67 "name cannot be empty or just consist of white space")
68 return rv
70 #: the name sort function
71 self.__name_sort_function: Final[Callable[[str], Any] | None] = \
72 name_sort_function
73 #: The key function of the grouper
74 self.key_func: Final[Callable] = key_func
75 #: The name function of the grouper
76 self.name_func: Final[Callable[[Any], str]] = \
77 cast("Callable[[Any], str]", __namer)
78 #: The base priority of this grouper
79 self.priority: float = float(priority)
80 #: The internal collection.
81 self.__collection: set = set()
82 #: the line colors
83 self.__line_colors: tuple | None = None
84 #: the line dashes
85 self.__line_dashes: tuple | None = None
86 #: the line widths
87 self.__line_widths: tuple[float, ...] | None = None
88 #: the optional line alpha
89 self.__line_alphas: tuple[float, ...] | None = None
91 def add(self, obj) -> None:
92 """
93 Add an object to the style collection.
95 :param obj: the object
96 """
97 self.__collection.add(self.key_func(obj))
99 def finalize(self) -> None:
100 """Compile the styler collection."""
101 self.has_none = (None in self.__collection)
102 if self.has_none:
103 self.__collection.remove(None)
105 nsf: Final[Callable[[str], Any] | None] = self.__name_sort_function
106 if nsf is None:
107 data = [(k, self.name_func(k)) for k in self.__collection]
108 data.sort()
109 if self.has_none:
110 data.insert(0, (None, self.name_func(None)))
111 self.keys = tuple(x[0] for x in data)
112 self.names = tuple(x[1] for x in data)
113 else:
114 data = [(self.name_func(k), k) for k in self.__collection]
115 data.sort(key=cast("Callable", lambda x, nsf2=nsf: nsf2(x[0])))
116 if self.has_none:
117 data.insert(0, (self.name_func(None), None))
118 self.names = tuple(x[0] for x in data)
119 self.keys = tuple(x[1] for x in data)
121 del self.__collection
122 del data
123 self.__indexes = {k: i for i, k in enumerate(self.keys)}
124 self.count = len(self.names)
125 self.priority += self.count
126 self.has_style = False
128 def __lt__(self, other) -> bool:
129 """
130 Check whether this styler is more important than another one.
132 :param other: the other styler
133 :return: `True` if it is, `False` if it is not.
134 """
135 if self.priority > other.priority:
136 return True
137 if self.priority < other.priority:
138 return False
139 c1 = self.count
140 if self.has_none:
141 c1 -= 1
142 c2 = other.count
143 if other.has_none:
144 c2 -= 1
145 return c1 > c2
147 def set_line_color(self, line_color_func: Callable) -> None:
148 """
149 Set that this styler should apply line colors.
151 :param line_color_func: a function returning the palette
152 """
153 tmp = line_color_func(self.count)
154 if not isinstance(tmp, Iterable):
155 raise type_error(tmp, "result of line color func", Iterable)
156 self.__line_colors = tuple(tmp)
157 if len(self.__line_colors) != self.count:
158 raise ValueError(f"There must be {self.count} line colors,"
159 f"but found only {len(self.__line_colors)}.")
160 self.has_style = True
162 def set_line_dash(self, line_dash_func: Callable) -> None:
163 """
164 Set that this styler should apply line dashes.
166 :param line_dash_func: a function returning the dashes
167 """
168 tmp = line_dash_func(self.count)
169 if not isinstance(tmp, Iterable):
170 raise type_error(tmp, "result of line dash func", Iterable)
171 self.__line_dashes = tuple(tmp)
172 if len(self.__line_dashes) != self.count:
173 raise ValueError(f"There must be {self.count} line dashes,"
174 f"but found only {len(self.__line_dashes)}.")
175 self.has_style = True
177 def set_line_width(self, line_width_func: Callable) -> None:
178 """
179 Set that this styler should apply a line width.
181 :param line_width_func: the line width function
182 """
183 tmp = line_width_func(self.count)
184 if not isinstance(tmp, Iterable):
185 raise type_error(tmp, "result of line width func", Iterable)
186 self.__line_widths = tuple(tmp)
187 if len(self.__line_widths) != self.count:
188 raise ValueError(f"There must be {self.count} line widths,"
189 f"but found only {len(self.__line_widths)}.")
190 self.has_style = True
192 def set_line_alpha(self, line_alpha_func: Callable) -> None:
193 """
194 Set that this styler should apply a line alpha.
196 :param line_alpha_func: the line alpha function
197 """
198 tmp = line_alpha_func(self.count)
199 if not isinstance(tmp, Iterable):
200 raise type_error(tmp, "result of line alpha func", Iterable)
201 self.__line_alphas = tuple(tmp)
202 if len(self.__line_alphas) != self.count:
203 raise ValueError(f"There must be {self.count} line alphas,"
204 f"but found only {len(self.__line_alphas)}.")
205 self.has_style = True
207 def add_line_style(self, obj,
208 style: dict[str, object]) -> None:
209 """
210 Apply this styler's contents based on the given object.
212 :param obj: the object for which the style should be created
213 :param style: the decode to which the styles should be added
214 """
215 key = self.key_func(obj)
216 index = self.__indexes.setdefault(key, -1)
217 if index >= 0:
218 self.__add_line_style(index, style)
220 def __add_line_style(self, index,
221 style: dict[str, object]) -> None:
222 """
223 Apply this styler's contents based on the given object.
225 :param index: the index to be processed
226 :param style: the decode to which the styles should be added
227 """
228 if self.__line_colors is not None:
229 style["color"] = self.__line_colors[index]
230 if self.__line_dashes is not None:
231 style["linestyle"] = self.__line_dashes[index]
232 if self.__line_widths is not None:
233 style["linewidth"] = self.__line_widths[index]
234 if self.__line_alphas is not None:
235 style["alpha"] = self.__line_alphas[index]
237 def add_to_legend(self, consumer: Callable[[Artist], Any]) -> None:
238 """
239 Add this styler to the legend.
241 :param consumer: the consumer to add to
242 """
243 if not callable(consumer):
244 raise type_error(consumer, "consumer", call=True)
245 for i, name in enumerate(self.names):
246 style = create_line_style()
247 self.__add_line_style(i, style)
248 style["label"] = name
249 style["xdata"] = []
250 style["ydata"] = []
251 consumer(Line2D(**style)) # type: ignore