Selfie
Loading...
Searching...
No Matches
WriteTracker.py
Go to the documentation of this file.
import inspect
import os
import threading
from abc import ABC, abstractmethod
from functools import total_ordering
from pathlib import Path
from typing import Generic, Optional, TypeVar, cast

from .FS import FS
from .Literals import LiteralString, LiteralTodoStub, LiteralValue, TodoStub
from .SourceFile import SourceFile
from .TypedPath import TypedPath
13
# Type parameters for the generic write trackers: T is the key a write is
# recorded under, U is the type of the snapshot value being written.
T = TypeVar("T")
U = TypeVar("U")
16
17
@total_ordering
class CallLocation:
    """A call site: an optional source file name plus a line number.

    Instances are hashable and totally ordered by ``(file_name, line)``;
    a location without a file name sorts before any location that has one.
    """

    def __init__(self, file_name: Optional[str], line: int):
        self._file_name = file_name
        self._line = line

    @property
    def file_name(self) -> Optional[str]:
        """The source file name, or ``None`` when it is not known."""
        return self._file_name

    @property
    def line(self) -> int:
        """The line number of the call."""
        return self._line

    def with_line(self, line: int) -> "CallLocation":
        """Return a new location in the same file but at ``line``."""
        return CallLocation(self._file_name, line)

    def ide_link(self, _: "SnapshotFileLayout") -> str:
        """Return a ``File: ..., Line: ...`` description (the layout is unused)."""
        return f"File: {self._file_name}, Line: {self._line}"

    def same_path_as(self, other: "CallLocation") -> bool:
        """Return True if ``other`` is a CallLocation with the same file name."""
        if not isinstance(other, CallLocation):
            return False
        return self._file_name == other.file_name

    # NOTE(review): the ``def`` line of this method was missing from the listing
    # this file was recovered from; the name is reconstructed - confirm callers.
    def source_filename_without_extension(self) -> str:
        """Return the file name with its extension removed, or "" if unknown.

        Only the last path component's extension is removed, so dots inside
        directory names (e.g. ``/home/first.last/tool``) are preserved.
        """
        if self._file_name is not None:
            return os.path.splitext(self._file_name)[0]
        return ""

    def _sort_key(self) -> tuple:
        # ``None < "str"`` raises TypeError, so order "no file name" first
        # explicitly instead of comparing the raw (file_name, line) tuples.
        return (self._file_name is not None, self._file_name or "", self._line)

    def __lt__(self, other) -> bool:
        if not isinstance(other, CallLocation):
            return NotImplemented
        return self._sort_key() < other._sort_key()

    def __eq__(self, other) -> bool:
        if not isinstance(other, CallLocation):
            return NotImplemented
        return (self._file_name, self._line) == (other.file_name, other.line)

    def __hash__(self):
        return hash((self._file_name, self._line))
60
61
class CallStack:
    """A primary call location plus the further locations that accompany it.

    ``recordCall`` builds these with ``location`` set to the first kept stack
    frame and ``rest_of_stack`` holding the remaining kept frames.
    """

    def __init__(self, location: CallLocation, rest_of_stack: list[CallLocation]):
        self.location = location
        self.rest_of_stack = rest_of_stack

    def ide_link(self, layout: "SnapshotFileLayout") -> str:
        """Return each location's ``ide_link``, one per line, ``location`` first."""
        return "\n".join(
            loc.ide_link(layout) for loc in [self.location, *self.rest_of_stack]
        )

    def __eq__(self, other):
        if not isinstance(other, CallStack):
            return NotImplemented
        return (
            self.location == other.location
            and self.rest_of_stack == other.rest_of_stack
        )

    def __hash__(self):
        # Lists are unhashable, so hash the remaining stack as a tuple.
        return hash((self.location, tuple(self.rest_of_stack)))
83
84
# NOTE(review): this class's header line was missing from the listing the file was
# recovered from. ``@abstractmethod`` below is only enforced when the class derives
# from ``ABC`` - confirm the intended base class.
class SnapshotFileLayout:
    """Resolves call locations to source files.

    ``fs`` is the file system callers use through the layout
    (e.g. ``layout.fs.file_read``).
    """

    def __init__(self, fs: FS):
        self.fs = fs

    @abstractmethod
    def root_folder(self) -> TypedPath:
        """Root folder of this layout; meaning is defined by subclasses."""
        pass

    def sourcefile_for_call(self, call: CallLocation) -> TypedPath:
        """Return the absolute path of the file that contains ``call``.

        Raises:
            ValueError: if ``call.file_name`` is ``None`` or empty.
        """
        file_path = call.file_name
        if not file_path:
            raise ValueError("No file path available in CallLocation.")
        return TypedPath(os.path.abspath(Path(file_path)))
98
99
def recordCall(callerFileOnly: bool) -> CallStack:
    """Capture the current call stack, skipping frames from this package.

    Frames are scanned from the innermost outward. Every frame before the first
    one whose module ``__package__`` differs from this module's is dropped; if
    no such frame exists, the whole stack is kept.

    Args:
        callerFileOnly: when True, keep only frames whose file name equals the
            first kept frame's file name.

    Returns:
        A CallStack whose ``location`` is the first kept frame and whose
        ``rest_of_stack`` holds the remaining kept frames in stack order.
    """
    # context=0: only file names and line numbers are used, so skip loading
    # source-code context lines for every frame.
    stack_frames = inspect.stack(context=0)
    try:
        first_real_frame = next(
            (
                i
                for i, info in enumerate(stack_frames)
                if info.frame.f_globals.get("__package__") != __package__
            ),
            None,
        )
        # Filter to only the stack after the selfie-lib package (a None index
        # keeps the full stack).
        stack_frames = stack_frames[first_real_frame:]

        if callerFileOnly:
            caller_file = stack_frames[0].filename
            stack_frames = [
                info for info in stack_frames if info.filename == caller_file
            ]

        call_locations = [
            CallLocation(info.filename, info.lineno) for info in stack_frames
        ]
    finally:
        # The frame records reference live frame objects, including this one;
        # release them so the reference cycle is broken deterministically.
        del stack_frames

    return CallStack(call_locations[0], call_locations[1:])
127
128
class FirstWrite(Generic[U]):
    """A snapshot value together with the call stack that wrote it.

    ``WriteTracker`` stores one of these for the first write to each key.
    """

    def __init__(self, snapshot: U, call_stack: CallStack):
        self.snapshot = snapshot
        self.call_stack = call_stack
133
134
class WriteTracker(ABC, Generic[T, U]):
    """Remembers the first snapshot written under each key.

    ``recordInternal`` holds ``self.lock`` while it reads and updates
    ``self.writes``; writing a different snapshot to a key that already has
    one is an error.
    """

    def __init__(self):
        self.lock = threading.Lock()
        # key -> the first write recorded for that key
        self.writes: dict[T, FirstWrite[U]] = {}

    def recordInternal(
        self,
        key: T,
        snapshot: U,
        call: CallStack,
        layout: SnapshotFileLayout,
        allow_multiple_equivalent_writes: bool = True,
    ):
        """Store ``snapshot`` for ``key``, or check it against the stored write.

        Args:
            key: what is being written.
            snapshot: the value being written.
            call: where this write happened; used in error messages.
            layout: used to render IDE links in error messages.
            allow_multiple_equivalent_writes: when False, writing an equal
                snapshot to the same key a second time is also an error.

        Raises:
            ValueError: if ``key`` already holds a different snapshot, or an
                equal one while ``allow_multiple_equivalent_writes`` is False.
        """
        with self.lock:
            existing = self.writes.get(key)
            if existing is None:
                # Stored values are always FirstWrite objects, so None means
                # the key has not been written yet.
                self.writes[key] = FirstWrite(snapshot, call)
                return

            if existing.snapshot != snapshot:
                raise ValueError(
                    f"Snapshot was set to multiple values!\n first time: {existing.call_stack.location.ide_link(layout)}\n this time: {call.location.ide_link(layout)}"
                )
            elif not allow_multiple_equivalent_writes:
                raise ValueError("Snapshot was set to the same value multiple times.")
161
162
164 def record(self, key: T, snapshot: U, call: CallStack, layout: SnapshotFileLayout):
165 super().recordInternal(key, snapshot, call, layout)
166
167
class InlineWriteTracker(WriteTracker[CallLocation, LiteralValue]):
    """Tracks literal values that are rewritten inline in source files.

    Writes are keyed by call location. ``record`` validates each write and
    ``persist_writes`` applies all recorded writes to the source files.
    """

    def hasWrites(self) -> bool:
        """Return True if at least one inline write has been recorded."""
        return bool(self.writes)

    def record(
        self,
        snapshot: LiteralValue,
        call: CallStack,
        layout: SnapshotFileLayout,
    ):
        """Record ``snapshot`` as the literal to write at ``call.location``.

        For string literals, the literal currently in the source file is parsed
        and compared with ``snapshot.expected``. A mismatch means Selfie could
        not read the existing literal correctly, so it refuses to rewrite it.

        Raises:
            ValueError: from ``recordInternal`` when a different value was
                already recorded at this location.
            AssertionError: if parsing the existing literal fails.
            The result of ``layout.fs.assert_failed``: if the parsed literal
                differs from ``snapshot.expected``.
        """
        super().recordInternal(call.location, snapshot, call, layout)

        # Resolved unconditionally, so any error from sourcefile_for_call
        # surfaces for every kind of literal, not only strings.
        file = layout.sourcefile_for_call(call.location)

        # isinstance(..., str) also rules out None.
        if not (
            isinstance(snapshot.expected, str)
            and isinstance(snapshot.format, LiteralString)
        ):
            return

        content = SourceFile(file.name, layout.fs.file_read(file))
        try:
            parsed_value = content.parse_to_be_like(call.location.line).parse_literal(
                snapshot.format
            )
        except Exception as e:
            raise AssertionError(
                f"Error while parsing the literal at {call.location.ide_link(layout)}. Please report this error at https://github.com/diffplug/selfie"
            ) from e
        if parsed_value != snapshot.expected:
            raise layout.fs.assert_failed(
                f"Selfie cannot modify the literal at {call.location.ide_link(layout)} because Selfie has a parsing bug. Please report this error at https://github.com/diffplug/selfie",
                snapshot.expected,
                parsed_value,
            )

    def persist_writes(self, layout: SnapshotFileLayout):
        """Rewrite the source files so they contain every recorded literal.

        Writes are sorted by (file name, line), so writes to the same file are
        adjacent and applied top to bottom. A file is written back when the
        next write belongs to a different file, and after the last write.
        Line-count changes from earlier literals in a file are carried forward
        to later writes in that file.
        """
        # Nothing was recorded, so there is nothing to persist.
        if not self.writes:
            return

        sorted_writes = sorted(
            self.writes.values(),
            key=lambda w: (w.call_stack.location.file_name, w.call_stack.location.line),
        )

        # Load the file of the first write up front.
        current_file = layout.sourcefile_for_call(sorted_writes[0].call_stack.location)
        content = SourceFile(current_file.name, layout.fs.file_read(current_file))
        delta_line_numbers = 0

        for write in sorted_writes:
            location = write.call_stack.location
            file_path = layout.sourcefile_for_call(location)
            if file_path != current_file:
                # Switching files: flush the previous file's changes first.
                layout.fs.file_write(current_file, content.as_string)
                current_file = file_path
                content = SourceFile(
                    current_file.name, layout.fs.file_read(current_file)
                )
                delta_line_numbers = 0

            # Shift the recorded line by the line-count changes made by
            # literals already rewritten earlier in this file.
            line = location.line + delta_line_numbers
            if isinstance(write.snapshot.format, LiteralTodoStub):
                kind: TodoStub = write.snapshot.actual  # type: ignore
                content.replace_on_line(line, f".{kind.name}_TODO(", f".{kind.name}(")
            else:
                to_be_literal = content.parse_to_be_like(line)
                delta_line_numbers += to_be_literal.set_literal_and_get_newline_delta(
                    write.snapshot
                )

        # Flush the last file processed.
        layout.fs.file_write(current_file, content.as_string)
248
249
class ToBeFileLazyBytes:
    """Bytes destined for a file, kept in memory until written to disk.

    After ``writeToDisk`` the in-memory copy is dropped and ``readData``
    reads the bytes back from ``location`` through ``layout.fs``.
    """

    def __init__(self, location: TypedPath, layout: SnapshotFileLayout, data: bytes):
        self.location = location
        self.layout = layout
        # Set to None once the bytes have been written to disk.
        self.data: Optional[bytes] = data

    def writeToDisk(self) -> None:
        """Write the held bytes to ``location``; may be called only once.

        Raises:
            RuntimeError: if the bytes were already written.
        """
        if self.data is None:
            raise RuntimeError("Data has already been written to disk")
        self.layout.fs.file_write_binary(self.location, self.data)
        self.data = None  # Allow garbage collection

    def readData(self):
        """Return the bytes from memory if still held, otherwise from disk."""
        if self.data is not None:
            return self.data
        return self.layout.fs.file_read_binary(self.location)

    def __eq__(self, other):
        if not isinstance(other, ToBeFileLazyBytes):
            return NotImplemented
        return self.readData() == other.readData()

    def __hash__(self):
        return hash(self.readData())
274
275
class ToBeFileWriteTracker(WriteTracker[TypedPath, ToBeFileLazyBytes]):
    """Tracks snapshots written to standalone files, keyed by file path."""

    def writeToDisk(
        self,
        key: TypedPath,
        snapshot: bytes,
        call: CallStack,
        layout: SnapshotFileLayout,
    ) -> None:
        """Record ``snapshot`` as the content of file ``key``, then write it.

        The write is recorded first, so a conflicting earlier write to the
        same path raises ``ValueError`` before anything reaches the disk.
        """
        lazy_bytes = ToBeFileLazyBytes(key, layout, snapshot)
        self.recordInternal(key, lazy_bytes, call, layout)
        lazy_bytes.writeToDisk()
bool same_path_as(self, "CallLocation" other)
__init__(self, Optional[str] file_name, int line)
str ide_link(self, "SnapshotFileLayout" _)
"CallLocation" with_line(self, int line)
str ide_link(self, "SnapshotFileLayout" layout)
__init__(self, CallLocation location, list[CallLocation] rest_of_stack)
record(self, T key, U snapshot, CallStack call, SnapshotFileLayout layout)
__init__(self, U snapshot, CallStack call_stack)
persist_writes(self, SnapshotFileLayout layout)
record(self, LiteralValue snapshot, CallStack call, SnapshotFileLayout layout)
TypedPath sourcefile_for_call(self, CallLocation call)
__init__(self, TypedPath location, SnapshotFileLayout layout, bytes data)
None writeToDisk(self, TypedPath key, bytes snapshot, CallStack call, SnapshotFileLayout layout)
recordInternal(self, T key, U snapshot, CallStack call, SnapshotFileLayout layout, bool allow_multiple_equivalent_writes=True)