diff --git a/src/BERDecode.cpp b/src/BERDecode.cpp index 5f041d7..e4d250f 100644 --- a/src/BERDecode.cpp +++ b/src/BERDecode.cpp @@ -72,7 +72,8 @@ int IntegerType::fromBuffer(const uint8_t *buf, size_t max_len){ const uint8_t* ptr = buf + i; unsigned short tempLength = _length; - uint32_t tempVal = 0; + _knownLen = _length; + uint32_t tempVal = 0; while(tempLength > 0){ tempVal = tempVal << 8; diff --git a/src/BEREncode.cpp b/src/BEREncode.cpp index fed7d8c..cdb3281 100644 --- a/src/BEREncode.cpp +++ b/src/BEREncode.cpp @@ -74,13 +74,22 @@ int NetworkAddress::serialise(uint8_t* buf, size_t max_len){ } int IntegerType::serialise(uint8_t* buf, size_t max_len){ - int i = BER_CONTAINER::serialise(buf, max_len, 4); + int i = BER_CONTAINER::serialise(buf, max_len, _knownLen); CHECK_ENCODE_ERR(i); uint8_t *ptr = buf + i; - *ptr++ = _value >> 24 & 0xFF; - *ptr++ = _value >> 16 & 0xFF; - *ptr++ = _value >> 8 & 0xFF; + if (_knownLen >= 4) { + *ptr++ = _value >> 24 & 0xFF; + } + + if (_knownLen >= 3) { + *ptr++ = _value >> 16 & 0xFF; + } + + if (_knownLen >= 2) { + *ptr++ = _value >> 8 & 0xFF; + } + *ptr++ = _value & 0xFF; return ptr - buf; diff --git a/src/SNMPPacket.cpp b/src/SNMPPacket.cpp index b9f886c..3c22161 100644 --- a/src/SNMPPacket.cpp +++ b/src/SNMPPacket.cpp @@ -43,7 +43,7 @@ SNMP_PACKET_PARSE_ERROR SNMPPacket::parsePacket(ComplexType *structure, enum SNM case SNMPVERSION: ASSERT_ASN_STATE_TYPE(value, SNMPVERSION); - this->snmpVersionPtr = std::static_pointer_cast(value); + this->snmpVersionPtr = std::static_pointer_cast(value); this->snmpVersion = (SNMP_VERSION) this->snmpVersionPtr.get()->_value; if (this->snmpVersion >= SNMP_VERSION_MAX) { SNMP_LOGW("Invalid SNMP Version: %d\n", this->snmpVersion); @@ -154,7 +154,7 @@ bool SNMPPacket::build(){ if(this->snmpVersionPtr) this->packet->addValueToList(this->snmpVersionPtr); else - this->packet->addValueToList(std::make_shared(this->snmpVersion)); + this->packet->addValueToList(std::make_shared(this->snmpVersion)); if(this->communityStringPtr) this->packet->addValueToList(this->communityStringPtr); @@ -168,9 +168,8 @@ bool SNMPPacket::build(){ else snmpPDU->addValueToList(std::make_shared(this->requestID)); - - snmpPDU->addValueToList(std::make_shared(this->errorStatus.errorStatus)); - snmpPDU->addValueToList(std::make_shared(this->errorIndex.errorIndex)); + snmpPDU->addValueToList(std::make_shared(this->errorStatus.errorStatus)); + snmpPDU->addValueToList(std::make_shared(this->errorIndex.errorIndex)); // We need to do this dynamically incase we're building a trap, generateVarBindList is virtual auto varBindList = this->generateVarBindList(); diff --git a/src/include/BER.h b/src/include/BER.h index 93033a7..252d352 100644 --- a/src/include/BER.h +++ b/src/include/BER.h @@ -128,12 +128,23 @@ class IntegerType: public BER_CONTAINER { }; int _value = 0; + int _knownLen = 4; protected: int serialise(uint8_t* buf, size_t max_len) override; int fromBuffer(const uint8_t *buf, size_t max_len) override; }; +class ByteType: public IntegerType { + public: + ByteType(): IntegerType() { + _knownLen = 1; + }; + explicit ByteType(uint8_t value): IntegerType(value) { + _knownLen = 1; + }; +}; + class TimestampType: public IntegerType { public: TimestampType(): IntegerType(){ diff --git a/src/include/SNMPPacket.h b/src/include/SNMPPacket.h index b4a0dc7..ac01987 100644 --- a/src/include/SNMPPacket.h +++ b/src/include/SNMPPacket.h @@ -73,7 +73,7 @@ class SNMPPacket { bool reuse = false; std::shared_ptr requestIDPtr = nullptr; - std::shared_ptr snmpVersionPtr = nullptr; + std::shared_ptr snmpVersionPtr = nullptr; std::shared_ptr communityStringPtr = nullptr; snmp_request_id_t requestID = 0; diff --git a/tests/tests.cpp b/tests/tests.cpp index c62ba53..5d0af4f 100644 --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -4,11 +4,11 @@ #include "include/SNMPPacket.h" #include "include/ValueCallbacks.h" #include "include/SNMPParser.h" - #include "SNMPTrap.h" - #include +const int expected_length = 123; // expected length of GenerateTestSNMPRequestPacket() + static SNMPPacket* GenerateTestSNMPRequestPacket(){ SNMPPacket* packet = new SNMPPacket(); @@ -32,13 +32,13 @@ TEST_CASE( "Test handle failures when Encoding/Decoding", "[snmp]"){ int serialised_length = 0; SECTION( "Failed Serialisation" ){ - serialised_length = packet->serialiseInto(buffer, 132); + serialised_length = packet->serialiseInto(buffer, expected_length - 1); REQUIRE( serialised_length <= 0 ); } SECTION( "Suceed Serialisation" ){ - serialised_length = packet->serialiseInto(buffer, 133); - REQUIRE( serialised_length == 133 ); + serialised_length = packet->serialiseInto(buffer, expected_length); + REQUIRE( serialised_length == expected_length ); } uint8_t copyBuffer[500] = {0}; @@ -47,7 +47,7 @@ TEST_CASE( "Test handle failures when Encoding/Decoding", "[snmp]"){ SECTION( "Should fail to parse a buffer too small"){ SNMPPacket* readPack = new SNMPPacket(); - REQUIRE( readPack->parseFrom(buffer, 130) != SNMP_ERROR_OK ); + REQUIRE( readPack->parseFrom(buffer, expected_length - 1) != SNMP_ERROR_OK ); } SECTION( "Decoding should not modify the buffer"){ @@ -56,12 +56,12 @@ TEST_CASE( "Test handle failures when Encoding/Decoding", "[snmp]"){ SECTION( "Should be able to reparse the buffer with correct max_size"){ SNMPPacket* readPack = new SNMPPacket(); - REQUIRE( readPack->parseFrom(buffer, 133) == SNMP_ERROR_OK ); + REQUIRE( readPack->parseFrom(buffer, expected_length) == SNMP_ERROR_OK ); } /* SECTION( "Should fail to parse a corrupt buffer "){ SNMPPacket* readPacket = new SNMPPacket(); - for(int i = 25; i < 133; i+= 10){ + for(int i = 25; i < expected_length; i+= 10){ char old[10] = {0}; memcpy(old, &buffer[i], 10); long randomLong = random(); @@ -83,7 +83,7 @@ TEST_CASE( "Test Encoding/Decoding packet", "[snmp]" ) { SECTION( "Serialisation" ){ serialised_length = packet->serialiseInto(buffer, 500); - REQUIRE( serialised_length == 133 ); + REQUIRE( serialised_length == expected_length ); } // Read packet SNMPPacket* readPacket = new SNMPPacket();