|
7 | 7 | # pyre-unsafe
|
8 | 8 |
|
9 | 9 | import copy
|
| 10 | +import torch |
10 | 11 | import random
|
11 | 12 | import statistics
|
12 | 13 | import tempfile
|
13 | 14 | import unittest
|
14 | 15 | from contextlib import redirect_stdout
|
15 | 16 |
|
16 |
| -from typing import Callable, List |
| 17 | +from typing import Callable, List, Union |
17 | 18 |
|
18 | 19 | from unittest.mock import patch
|
19 | 20 |
|
|
56 | 57 |
|
57 | 58 | OP_TYPE = "aten::add"
|
58 | 59 | EVENT_BLOCK_NAME = "block_0"
|
59 |
| -EVENTS_SIZE = 5 |
| 60 | +EVENTS_SIZE = 10 |
60 | 61 | RAW_DATA_SIZE = 10
|
61 | 62 | ETDUMP_PATH = "unittest_etdump_path"
|
62 | 63 | ETRECORD_PATH = "unittest_etrecord_path"
|
@@ -535,17 +536,111 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
|
535 | 536 | )
|
536 | 537 | )
|
537 | 538 |
|
| 539 | + def test_get_runtime_intermediate_outputs(self): |
| 540 | + # Create a context manager to patch functions called by Inspector.__init__ |
| 541 | + with patch.object( |
| 542 | + _inspector, "parse_etrecord", return_value=None |
| 543 | + ), patch.object( |
| 544 | + _inspector, "gen_etdump_object", return_value=None |
| 545 | + ), patch.object( |
| 546 | + EventBlock, "_gen_from_etdump" |
| 547 | + ), patch.object( |
| 548 | + _inspector, "gen_graphs_from_etrecord" |
| 549 | + ): |
| 550 | + # Call the constructor of Inspector |
| 551 | + inspector_instance = Inspector( |
| 552 | + etdump_path=ETDUMP_PATH, |
| 553 | + etrecord=ETRECORD_PATH, |
| 554 | + ) |
| 555 | + |
| 556 | + # The mock inspector instance starts with having an empty event blocks list. |
| 557 | + # Add pre-defined event blocks to test _get_runtime_outputs(). |
| 558 | + inspector_instance.event_blocks = [ |
| 559 | + EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events()) |
| 560 | + ] |
| 561 | + |
| 562 | + runtime_outputs = inspector_instance._get_runtime_intermediate_outputs() |
| 563 | + # This output should be a dictionary with 5 keys |
| 564 | + self.assertEqual(len(runtime_outputs), 5, ) |
| 565 | + # Check that keys (0,) and (1,) are not in the dictionary(skip OPERATOR_CALL and op_types are empty) |
| 566 | + self.assertNotIn((0,), runtime_outputs) |
| 567 | + self.assertNotIn((1,), runtime_outputs) |
| 568 | + |
| 569 | + # Same debug_handle but different instruction_id, should record the last one |
| 570 | + self.assertIn((4,), runtime_outputs) |
| 571 | + self.assertTrue(torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))) |
| 572 | + # Check that keys (5,) to (8,) are in the dictionary and have values of the correct size |
| 573 | + for key in range(5, 9): |
| 574 | + self.assertIn((key,), runtime_outputs) |
| 575 | + self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE) |
| 576 | + |
538 | 577 | def _gen_random_float_list(self) -> List[float]:
|
539 | 578 | return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]
|
540 | 579 |
|
| 580 | + def _gen_random_runtime_output(self) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]: |
| 581 | + return list(torch.randn(RAW_DATA_SIZE)) |
| 582 | + |
541 | 583 | def _gen_random_events(self) -> List[Event]:
|
542 | 584 | events = []
|
543 |
| - for i in range(EVENTS_SIZE): |
| 585 | + for i in range(2): |
| 586 | + events.append( |
| 587 | + # OPERATOR_CALL with debug_hanldes/instruction_id 0 and 2 |
| 588 | + Event( |
| 589 | + name="OPERATOR_CALL", |
| 590 | + op_types=[OP_TYPE], |
| 591 | + perf_data=PerfData(self._gen_random_float_list()), |
| 592 | + debug_handles = i * 2, |
| 593 | + _instruction_id = i * 2, |
| 594 | + debug_data = self._gen_random_runtime_output() |
| 595 | + ) |
| 596 | + ) |
| 597 | + events.append( |
| 598 | + # op_0/op_1 wiht empty op_types and with debug_hanldes/instruction_id 1 and 3 |
| 599 | + Event( |
| 600 | + name=f"op_{i}", |
| 601 | + op_types=[], |
| 602 | + perf_data=PerfData(self._gen_random_float_list()), |
| 603 | + debug_handles = i * 2 + 1, |
| 604 | + _instruction_id = i * 2 + 1, |
| 605 | + debug_data = self._gen_random_runtime_output() |
| 606 | + ) |
| 607 | + ) |
| 608 | + |
| 609 | + # op_2 with debug_hanldes/instruction_id 4 |
| 610 | + events.append( |
| 611 | + Event( |
| 612 | + name=f"op_2", |
| 613 | + op_types=[OP_TYPE], |
| 614 | + perf_data=PerfData(self._gen_random_float_list()), |
| 615 | + debug_handles = 4, |
| 616 | + debug_data = [torch.tensor([1.0, 2.0, 3.0])], |
| 617 | + _instruction_id = 4 |
| 618 | + |
| 619 | + ) |
| 620 | + ) |
| 621 | + # op_3 also with debug_hanldes 4 but with instruction_id 5 |
| 622 | + events.append( |
| 623 | + Event( |
| 624 | + name=f"op_3", |
| 625 | + op_types=[OP_TYPE], |
| 626 | + perf_data=PerfData(self._gen_random_float_list()), |
| 627 | + debug_handles = 4, |
| 628 | + debug_data = [torch.tensor([4.0, 5.0, 6.0])], |
| 629 | + _instruction_id = 5 |
| 630 | + |
| 631 | + ) |
| 632 | + ) |
| 633 | + |
| 634 | + # op_4 to op_7 with debug_hanldes 5 to 8 and instruction_id 6 to 9 |
| 635 | + for i in range(4, EVENTS_SIZE - 2): |
544 | 636 | events.append(
|
545 | 637 | Event(
|
546 | 638 | name=f"op_{i}",
|
547 | 639 | op_types=[OP_TYPE],
|
548 | 640 | perf_data=PerfData(self._gen_random_float_list()),
|
| 641 | + debug_handles = i + 1, |
| 642 | + debug_data = self._gen_random_runtime_output(), |
| 643 | + _instruction_id = i + 2 |
549 | 644 | )
|
550 | 645 | )
|
551 | 646 | return events
|
0 commit comments