Make MaxDataReader more generic and precise

This commit is contained in:
Klaas van Schelven
2024-04-25 13:58:26 +02:00
parent 387234b847
commit a19f1ff712
2 changed files with 24 additions and 11 deletions

View File

@@ -98,19 +98,15 @@ class MaxDataReader:
def read(self, size=None):
    """Read up to `size` bytes from the wrapped stream, enforcing a cap.

    Tracks the total number of bytes actually read; once that total
    exceeds `self.max_length`, ValueError("Max length (...) exceeded")
    is raised.

    When `size` is None, read "as much as is still allowed": we delegate
    to a sized read of the remaining budget plus one byte, so that a
    stream holding more data than the budget trips the check below.
    """
    if size is None:
        return self.read(self.max_length - self.bytes_read + 1)  # +1 to trigger the max length check
    result = self.stream.read(size)
    # Count the bytes actually returned (not the bytes requested), so a
    # short read near end-of-stream does not raise spuriously.
    self.bytes_read += len(result)
    if self.bytes_read > self.max_length:
        raise ValueError("Max length (%s) exceeded" % self.max_length)
    return result
class MaxDataWriter:

View File

@@ -257,8 +257,25 @@ class StreamsTestCase(RegularTestCase):
for i in range(25):
self.assertEquals(b"hellohello", reader.read(10))
with self.assertRaises(ValueError):
self.assertEquals(b"hellohello", reader.read(10))
with self.assertRaises(ValueError) as e:
reader.read(10)
self.assertEquals("Max length (250) exceeded", str(e.exception))
def test_max_data_reader_none_ok(self):
    """A size=None read returns all remaining data while the total stays under the cap."""
    payload = b"hello" * 10
    capped = MaxDataReader(io.BytesIO(payload), 250)
    self.assertEquals(payload, capped.read(None))
def test_max_data_reader_none_fail(self):
    """A size=None read raises, with the limit in the message, once max_length is exceeded."""
    capped = MaxDataReader(io.BytesIO(b"hello" * 100), 250)
    with self.assertRaises(ValueError) as e:
        capped.read(None)
    self.assertEquals("Max length (250) exceeded", str(e.exception))
def test_max_data_writer(self):
stream = io.BytesIO()