diff --git a/ccsdspy/packet_types.py b/ccsdspy/packet_types.py index d07ec18..123416f 100644 --- a/ccsdspy/packet_types.py +++ b/ccsdspy/packet_types.py @@ -141,7 +141,7 @@ def __init__(self, fields): self._init(fields) - def load(self, file, include_primary_header=False): + def load(self, file, include_primary_header=False, reset_file_obj=False): """Decode a file-like object containing a sequence of these packets. Parameters @@ -153,6 +153,10 @@ def load(self, file, include_primary_header=False): fields are: `CCSDS_VERSION_NUMBER`, `CCSDS_PACKET_TYPE`, `CCSDS_SECONDARY_FLAG`, `CCSDS_SEQUENCE_FLAG`, `CCSDS_APID`, `CCSDS_SEQUENCE_COUNT`, and `CCSDS_PACKET_LENGTH` + reset_file_obj : bool + If True, leave the file object, when it is file buffer, where it was before load is called. + Otherwise, (default), leave the file stream pos after the read packets. + Does not apply when file is a string. Returns ------- @@ -175,6 +179,7 @@ def load(self, file, include_primary_header=False): self._converters, "fixed_length", include_primary_header=True, + reset_file_obj=reset_file_obj, ) # inspect the primary header and issue warning if appropriate @@ -262,7 +267,7 @@ def __init__(self, fields): self._init(fields) - def load(self, file, include_primary_header=False): + def load(self, file, include_primary_header=False, reset_file_obj=False): """Decode a file-like object containing a sequence of these packets. Parameters @@ -274,6 +279,10 @@ def load(self, file, include_primary_header=False): fields are: `CCSDS_VERSION_NUMBER`, `CCSDS_PACKET_TYPE`, `CCSDS_SECONDARY_FLAG`, `CCSDS_SEQUENCE_FLAG`, `CCSDS_APID`, `CCSDS_SEQUENCE_COUNT`, and `CCSDS_PACKET_LENGTH` + reset_file_obj : bool + If True, leave the file object, when it is file buffer, where it was before load is called. + Otherwise, (default), leave the file stream pos after the read packets. + Does not apply when file is a string. Returns ------- @@ -294,7 +303,12 @@ def load(self, file, include_primary_header=False): # they didn't want the primary header fields, we parse for them and then # remove them after. packet_arrays = _load( - file, self._fields, self._converters, "variable_length", include_primary_header=True + file, + self._fields, + self._converters, + "variable_length", + include_primary_header=True, + reset_file_obj=reset_file_obj, ) # inspect the primary header and issue warning if appropriate @@ -593,7 +607,9 @@ def _get_fields_csv_file(csv_file): return fields -def _load(file, fields, converters, decoder_name, include_primary_header=False): +def _load( + file, fields, converters, decoder_name, include_primary_header=False, reset_file_obj=False +): """Decode a file-like object containing a sequence of these packets. Parameters @@ -609,6 +625,10 @@ def _load(file, fields, converters, decoder_name, include_primary_header=False): String identifying which decoder to use. include_primary_header: bool If True, provides the primary header in the output + reset_file_obj : bool + If True, leave the file object, when it is a file buffer, where it was before _load is called. + Otherwise, (default), leave the file stream pos after the read packets. + Does not apply when file is a string. Returns ------- @@ -621,6 +641,7 @@ def _load(file, fields, converters, decoder_name, include_primary_header=False): the decoder_name is not one of the allowed values """ if hasattr(file, "read"): + file_pos = file.tell() file_bytes = np.frombuffer(file.read(), "u1") else: file_bytes = np.fromfile(file, "u1") @@ -646,6 +667,8 @@ def _load(file, fields, converters, decoder_name, include_primary_header=False): field_arrays = _apply_post_byte_reoderings(field_arrays, orig_fields) field_arrays = _apply_converters(field_arrays, converters) + if hasattr(file, "read") and reset_file_obj: + file.seek(file_pos) return field_arrays diff --git a/ccsdspy/tests/test_byte_order.py b/ccsdspy/tests/test_byte_order.py index 7be7617..5778b67 100644 --- a/ccsdspy/tests/test_byte_order.py +++ b/ccsdspy/tests/test_byte_order.py @@ -6,6 +6,7 @@ See also: ccsdspy/tests/data/byte_order/byte_order_packets.py """ + import glob import itertools import os diff --git a/ccsdspy/tests/test_packet_types.py b/ccsdspy/tests/test_packet_types.py index a9cb834..2bab5c3 100644 --- a/ccsdspy/tests/test_packet_types.py +++ b/ccsdspy/tests/test_packet_types.py @@ -20,6 +20,12 @@ csv_file_4col_with_array = os.path.join(packet_def_dir, "simple_csv_4col_with_array.csv") csv_file_3col_with_array = os.path.join(packet_def_dir, "simple_csv_3col_with_array.csv") +hs_packet_dir = os.path.join(dir_path, "data", "hs") +random_binary_file = os.path.join( + hs_packet_dir, "apid001", "SSAT1_2015-180-00-00-00_2015-180-01-59-58_1_1_sim.tlm" +) +random_packet_def = os.path.join(hs_packet_dir, "apid001", "defs.csv") + def test_FixedLength_initializer_copies_field_list(): """Tests that the FixedLengthPacket initializer stores a copy of the @@ -229,3 +235,12 @@ def test_variable_length_rejects_bit_offset(): ), ] ) + + +def test_load_without_moving_file_buffer_pos(): + """Tests that load(..., reset_file_obj=True) works as intended.""" + pkts = FixedLength.from_file(random_packet_def) + with open(random_binary_file, "rb") as fp: + pos = fp.tell() + pkts.load(fp, reset_file_obj=True) + assert pos == fp.tell()