1+ import json
2+ import os
13from collections import defaultdict
24from dataclasses import (
35 dataclass ,
@@ -46,6 +48,20 @@ class ItemStatus(Enum):
4648 SKIPPED = "skipped"
4749
4850
51+ class _FakePytestObject :
52+ def __init__ (self , collected_item : dict [str , str ]) -> None :
53+ self .__module__ = collected_item ["modulename" ]
54+ self .__name__ = collected_item ["methodname" ]
55+
56+
57+ class _FakePytestItem :
58+ def __init__ (self , collected_item : dict [str , str ]) -> None :
59+ self .nodeid = collected_item ["nodeid" ]
60+ self .name = collected_item ["name" ]
61+ self .path = Path (collected_item ["path" ])
62+ self .obj = _FakePytestObject (collected_item )
63+
64+
4965@dataclass
5066class SnapshotSession :
5167 pytest_session : "pytest.Session"
@@ -127,6 +143,24 @@ def ran_item(
127143 except ValueError :
128144 pass # if we don't understand the outcome, leave the item as "not run"
129145
146+ def _merge_collected_items (self , collected_items : list [dict [str , str ]]) -> None :
147+ for collected_item in collected_items :
148+ custom_item = _FakePytestItem (collected_item )
149+ if not any (
150+ t .nodeid == custom_item .nodeid and t .name == custom_item .nodeid
151+ for t in self ._collected_items
152+ ):
153+ self ._collected_items .add (custom_item ) # type: ignore[arg-type]
154+
155+ def _merge_selected_items (self , selected_items : dict [str , str ]) -> None :
156+ for key , selected_item in selected_items .items ():
157+ if key in self ._selected_items :
158+ status = ItemStatus (selected_item )
159+ if status != ItemStatus .NOT_RUN :
160+ self ._selected_items [key ] = status
161+ else :
162+ self ._selected_items [key ] = ItemStatus (selected_item )
163+
130164 def finish (self ) -> int :
131165 exitstatus = 0
132166 self .flush_snapshot_write_queue ()
@@ -139,16 +173,39 @@ def finish(self) -> int:
139173 )
140174
141175 if is_xdist_worker ():
142- # TODO: If we're in a pytest-xdist worker, we need to combine the reports
143- # of all the workers so that the controller can handle unused
144- # snapshot removal.
176+ worker_count = os .getenv ("PYTEST_XDIST_WORKER_COUNT" )
177+ with open (".pytest_syrupy_worker_count" , "w" , encoding = "utf-8" ) as f :
178+ f .write (worker_count ) # type: ignore[arg-type]
179+ with open (
180+ f".pytest_syrupy_{ os .getenv ("PYTEST_XDIST_WORKER" )} _result" ,
181+ "w" ,
182+ encoding = "utf-8" ,
183+ ) as f :
184+ json .dump (self .report .serialize (), f , indent = 2 )
145185 return exitstatus
146186 elif is_xdist_controller ():
147187 # TODO: If we're in a pytest-xdist controller, merge all the reports.
148188 # Until this is implemented, running syrupy with pytest-xdist is only
149189 # partially functional.
150190 return exitstatus
151191
192+ worker_count = None
193+ try :
194+ with open (".pytest_syrupy_worker_count" , encoding = "utf-8" ) as f :
195+ worker_count = f .read ()
196+ os .remove (".pytest_syrupy_worker_count" )
197+ except FileNotFoundError :
198+ pass
199+
200+ if worker_count :
201+ for i in range (int (worker_count )):
202+ with open (f".pytest_syrupy_gw{ i } _result" , encoding = "utf-8" ) as f :
203+ data = json .load (f )
204+ self ._merge_collected_items (data ["_collected_items" ])
205+ self ._merge_selected_items (data ["_selected_items" ])
206+ self .report .merge_serialized (data )
207+ os .remove (f".pytest_syrupy_gw{ i } _result" )
208+
152209 if self .report .num_unused :
153210 if self .update_snapshots :
154211 self .remove_unused_snapshots (
0 commit comments