From 539364846a89987ac2679988653f50332cb91d26 Mon Sep 17 00:00:00 2001 From: "madmaxoft@gmail.com" Date: Thu, 30 Aug 2012 21:06:13 +0000 Subject: Implemented 1.3.2 protocol encryption using CryptoPP, up to Client Status packet (http://wiki.vg/Protocol_FAQ step 14) git-svn-id: http://mc-server.googlecode.com/svn/trunk@808 0a769ca7-a7f5-676a-18bf-c427514a06d6 --- CryptoPP/cryptlib.cpp | 828 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 828 insertions(+) create mode 100644 CryptoPP/cryptlib.cpp (limited to 'CryptoPP/cryptlib.cpp') diff --git a/CryptoPP/cryptlib.cpp b/CryptoPP/cryptlib.cpp new file mode 100644 index 000000000..df138ddb0 --- /dev/null +++ b/CryptoPP/cryptlib.cpp @@ -0,0 +1,828 @@ +// cryptlib.cpp - written and placed in the public domain by Wei Dai + +#include "pch.h" + +#ifndef CRYPTOPP_IMPORTS + +#include "cryptlib.h" +#include "misc.h" +#include "filters.h" +#include "algparam.h" +#include "fips140.h" +#include "argnames.h" +#include "fltrimpl.h" +#include "trdlocal.h" +#include "osrng.h" + +#include + +NAMESPACE_BEGIN(CryptoPP) + +CRYPTOPP_COMPILE_ASSERT(sizeof(byte) == 1); +CRYPTOPP_COMPILE_ASSERT(sizeof(word16) == 2); +CRYPTOPP_COMPILE_ASSERT(sizeof(word32) == 4); +CRYPTOPP_COMPILE_ASSERT(sizeof(word64) == 8); +#ifdef CRYPTOPP_NATIVE_DWORD_AVAILABLE +CRYPTOPP_COMPILE_ASSERT(sizeof(dword) == 2*sizeof(word)); +#endif + +const std::string DEFAULT_CHANNEL; +const std::string AAD_CHANNEL = "AAD"; +const std::string &BufferedTransformation::NULL_CHANNEL = DEFAULT_CHANNEL; + +class NullNameValuePairs : public NameValuePairs +{ +public: + bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const {return false;} +}; + +simple_ptr s_pNullNameValuePairs(new NullNameValuePairs); +const NameValuePairs &g_nullNameValuePairs = *s_pNullNameValuePairs.m_p; + +BufferedTransformation & TheBitBucket() +{ + static BitBucket bitBucket; + return bitBucket; +} + +Algorithm::Algorithm(bool checkSelfTestStatus) +{ + if (checkSelfTestStatus && FIPS_140_2_ComplianceEnabled()) + { + if (GetPowerUpSelfTestStatus() == POWER_UP_SELF_TEST_NOT_DONE && !PowerUpSelfTestInProgressOnThisThread()) + throw SelfTestFailure("Cryptographic algorithms are disabled before the power-up self tests are performed."); + + if (GetPowerUpSelfTestStatus() == POWER_UP_SELF_TEST_FAILED) + throw SelfTestFailure("Cryptographic algorithms are disabled after a power-up self test failed."); + } +} + +void SimpleKeyingInterface::SetKey(const byte *key, size_t length, const NameValuePairs ¶ms) +{ + this->ThrowIfInvalidKeyLength(length); + this->UncheckedSetKey(key, (unsigned int)length, params); +} + +void SimpleKeyingInterface::SetKeyWithRounds(const byte *key, size_t length, int rounds) +{ + SetKey(key, length, MakeParameters(Name::Rounds(), rounds)); +} + +void SimpleKeyingInterface::SetKeyWithIV(const byte *key, size_t length, const byte *iv, size_t ivLength) +{ + SetKey(key, length, MakeParameters(Name::IV(), ConstByteArrayParameter(iv, ivLength))); +} + +void SimpleKeyingInterface::ThrowIfInvalidKeyLength(size_t length) +{ + if (!IsValidKeyLength(length)) + throw InvalidKeyLength(GetAlgorithm().AlgorithmName(), length); +} + +void SimpleKeyingInterface::ThrowIfResynchronizable() +{ + if (IsResynchronizable()) + throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": this object requires an IV"); +} + +void SimpleKeyingInterface::ThrowIfInvalidIV(const byte *iv) +{ + if (!iv && IVRequirement() == UNPREDICTABLE_RANDOM_IV) + throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": this object cannot use a null IV"); +} + +size_t SimpleKeyingInterface::ThrowIfInvalidIVLength(int size) +{ + if (size < 0) + return IVSize(); + else if ((size_t)size < MinIVLength()) + throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": IV length " + IntToString(size) + " is less than the minimum of " + IntToString(MinIVLength())); + else if ((size_t)size > MaxIVLength()) + throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": IV length " + IntToString(size) + " exceeds the maximum of " + IntToString(MaxIVLength())); + else + return size; +} + +const byte * SimpleKeyingInterface::GetIVAndThrowIfInvalid(const NameValuePairs ¶ms, size_t &size) +{ + ConstByteArrayParameter ivWithLength; + const byte *iv; + bool found = false; + + try {found = params.GetValue(Name::IV(), ivWithLength);} + catch (const NameValuePairs::ValueTypeMismatch &) {} + + if (found) + { + iv = ivWithLength.begin(); + ThrowIfInvalidIV(iv); + size = ThrowIfInvalidIVLength((int)ivWithLength.size()); + return iv; + } + else if (params.GetValue(Name::IV(), iv)) + { + ThrowIfInvalidIV(iv); + size = IVSize(); + return iv; + } + else + { + ThrowIfResynchronizable(); + size = 0; + return NULL; + } +} + +void SimpleKeyingInterface::GetNextIV(RandomNumberGenerator &rng, byte *IV) +{ + rng.GenerateBlock(IV, IVSize()); +} + +size_t BlockTransformation::AdvancedProcessBlocks(const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags) const +{ + size_t blockSize = BlockSize(); + size_t inIncrement = (flags & (BT_InBlockIsCounter|BT_DontIncrementInOutPointers)) ? 0 : blockSize; + size_t xorIncrement = xorBlocks ? blockSize : 0; + size_t outIncrement = (flags & BT_DontIncrementInOutPointers) ? 0 : blockSize; + + if (flags & BT_ReverseDirection) + { + assert(length % blockSize == 0); + inBlocks += length - blockSize; + xorBlocks += length - blockSize; + outBlocks += length - blockSize; + inIncrement = 0-inIncrement; + xorIncrement = 0-xorIncrement; + outIncrement = 0-outIncrement; + } + + while (length >= blockSize) + { + if (flags & BT_XorInput) + { + xorbuf(outBlocks, xorBlocks, inBlocks, blockSize); + ProcessBlock(outBlocks); + } + else + ProcessAndXorBlock(inBlocks, xorBlocks, outBlocks); + if (flags & BT_InBlockIsCounter) + const_cast(inBlocks)[blockSize-1]++; + inBlocks += inIncrement; + outBlocks += outIncrement; + xorBlocks += xorIncrement; + length -= blockSize; + } + + return length; +} + +unsigned int BlockTransformation::OptimalDataAlignment() const +{ + return GetAlignmentOf(); +} + +unsigned int StreamTransformation::OptimalDataAlignment() const +{ + return GetAlignmentOf(); +} + +unsigned int HashTransformation::OptimalDataAlignment() const +{ + return GetAlignmentOf(); +} + +void StreamTransformation::ProcessLastBlock(byte *outString, const byte *inString, size_t length) +{ + assert(MinLastBlockSize() == 0); // this function should be overriden otherwise + + if (length == MandatoryBlockSize()) + ProcessData(outString, inString, length); + else if (length != 0) + throw NotImplemented(AlgorithmName() + ": this object does't support a special last block"); +} + +void AuthenticatedSymmetricCipher::SpecifyDataLengths(lword headerLength, lword messageLength, lword footerLength) +{ + if (headerLength > MaxHeaderLength()) + throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": header length " + IntToString(headerLength) + " exceeds the maximum of " + IntToString(MaxHeaderLength())); + + if (messageLength > MaxMessageLength()) + throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": message length " + IntToString(messageLength) + " exceeds the maximum of " + IntToString(MaxMessageLength())); + + if (footerLength > MaxFooterLength()) + throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": footer length " + IntToString(footerLength) + " exceeds the maximum of " + IntToString(MaxFooterLength())); + + UncheckedSpecifyDataLengths(headerLength, messageLength, footerLength); +} + +void AuthenticatedSymmetricCipher::EncryptAndAuthenticate(byte *ciphertext, byte *mac, size_t macSize, const byte *iv, int ivLength, const byte *header, size_t headerLength, const byte *message, size_t messageLength) +{ + Resynchronize(iv, ivLength); + SpecifyDataLengths(headerLength, messageLength); + Update(header, headerLength); + ProcessString(ciphertext, message, messageLength); + TruncatedFinal(mac, macSize); +} + +bool AuthenticatedSymmetricCipher::DecryptAndVerify(byte *message, const byte *mac, size_t macLength, const byte *iv, int ivLength, const byte *header, size_t headerLength, const byte *ciphertext, size_t ciphertextLength) +{ + Resynchronize(iv, ivLength); + SpecifyDataLengths(headerLength, ciphertextLength); + Update(header, headerLength); + ProcessString(message, ciphertext, ciphertextLength); + return TruncatedVerify(mac, macLength); +} + +unsigned int RandomNumberGenerator::GenerateBit() +{ + return GenerateByte() & 1; +} + +byte RandomNumberGenerator::GenerateByte() +{ + byte b; + GenerateBlock(&b, 1); + return b; +} + +word32 RandomNumberGenerator::GenerateWord32(word32 min, word32 max) +{ + word32 range = max-min; + const int maxBits = BitPrecision(range); + + word32 value; + + do + { + GenerateBlock((byte *)&value, sizeof(value)); + value = Crop(value, maxBits); + } while (value > range); + + return value+min; +} + +void RandomNumberGenerator::GenerateBlock(byte *output, size_t size) +{ + ArraySink s(output, size); + GenerateIntoBufferedTransformation(s, DEFAULT_CHANNEL, size); +} + +void RandomNumberGenerator::DiscardBytes(size_t n) +{ + GenerateIntoBufferedTransformation(TheBitBucket(), DEFAULT_CHANNEL, n); +} + +void RandomNumberGenerator::GenerateIntoBufferedTransformation(BufferedTransformation &target, const std::string &channel, lword length) +{ + FixedSizeSecBlock buffer; + while (length) + { + size_t len = UnsignedMin(buffer.size(), length); + GenerateBlock(buffer, len); + target.ChannelPut(channel, buffer, len); + length -= len; + } +} + +//! see NullRNG() +class ClassNullRNG : public RandomNumberGenerator +{ +public: + std::string AlgorithmName() const {return "NullRNG";} + void GenerateBlock(byte *output, size_t size) {throw NotImplemented("NullRNG: NullRNG should only be passed to functions that don't need to generate random bytes");} +}; + +RandomNumberGenerator & NullRNG() +{ + static ClassNullRNG s_nullRNG; + return s_nullRNG; +} + +bool HashTransformation::TruncatedVerify(const byte *digestIn, size_t digestLength) +{ + ThrowIfInvalidTruncatedSize(digestLength); + SecByteBlock digest(digestLength); + TruncatedFinal(digest, digestLength); + return VerifyBufsEqual(digest, digestIn, digestLength); +} + +void HashTransformation::ThrowIfInvalidTruncatedSize(size_t size) const +{ + if (size > DigestSize()) + throw InvalidArgument("HashTransformation: can't truncate a " + IntToString(DigestSize()) + " byte digest to " + IntToString(size) + " bytes"); +} + +unsigned int BufferedTransformation::GetMaxWaitObjectCount() const +{ + const BufferedTransformation *t = AttachedTransformation(); + return t ? t->GetMaxWaitObjectCount() : 0; +} + +void BufferedTransformation::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack) +{ + BufferedTransformation *t = AttachedTransformation(); + if (t) + t->GetWaitObjects(container, callStack); // reduce clutter by not adding to stack here +} + +void BufferedTransformation::Initialize(const NameValuePairs ¶meters, int propagation) +{ + assert(!AttachedTransformation()); + IsolatedInitialize(parameters); +} + +bool BufferedTransformation::Flush(bool hardFlush, int propagation, bool blocking) +{ + assert(!AttachedTransformation()); + return IsolatedFlush(hardFlush, blocking); +} + +bool BufferedTransformation::MessageSeriesEnd(int propagation, bool blocking) +{ + assert(!AttachedTransformation()); + return IsolatedMessageSeriesEnd(blocking); +} + +byte * BufferedTransformation::ChannelCreatePutSpace(const std::string &channel, size_t &size) +{ + if (channel.empty()) + return CreatePutSpace(size); + else + throw NoChannelSupport(AlgorithmName()); +} + +size_t BufferedTransformation::ChannelPut2(const std::string &channel, const byte *begin, size_t length, int messageEnd, bool blocking) +{ + if (channel.empty()) + return Put2(begin, length, messageEnd, blocking); + else + throw NoChannelSupport(AlgorithmName()); +} + +size_t BufferedTransformation::ChannelPutModifiable2(const std::string &channel, byte *begin, size_t length, int messageEnd, bool blocking) +{ + if (channel.empty()) + return PutModifiable2(begin, length, messageEnd, blocking); + else + return ChannelPut2(channel, begin, length, messageEnd, blocking); +} + +bool BufferedTransformation::ChannelFlush(const std::string &channel, bool completeFlush, int propagation, bool blocking) +{ + if (channel.empty()) + return Flush(completeFlush, propagation, blocking); + else + throw NoChannelSupport(AlgorithmName()); +} + +bool BufferedTransformation::ChannelMessageSeriesEnd(const std::string &channel, int propagation, bool blocking) +{ + if (channel.empty()) + return MessageSeriesEnd(propagation, blocking); + else + throw NoChannelSupport(AlgorithmName()); +} + +lword BufferedTransformation::MaxRetrievable() const +{ + if (AttachedTransformation()) + return AttachedTransformation()->MaxRetrievable(); + else + return CopyTo(TheBitBucket()); +} + +bool BufferedTransformation::AnyRetrievable() const +{ + if (AttachedTransformation()) + return AttachedTransformation()->AnyRetrievable(); + else + { + byte b; + return Peek(b) != 0; + } +} + +size_t BufferedTransformation::Get(byte &outByte) +{ + if (AttachedTransformation()) + return AttachedTransformation()->Get(outByte); + else + return Get(&outByte, 1); +} + +size_t BufferedTransformation::Get(byte *outString, size_t getMax) +{ + if (AttachedTransformation()) + return AttachedTransformation()->Get(outString, getMax); + else + { + ArraySink arraySink(outString, getMax); + return (size_t)TransferTo(arraySink, getMax); + } +} + +size_t BufferedTransformation::Peek(byte &outByte) const +{ + if (AttachedTransformation()) + return AttachedTransformation()->Peek(outByte); + else + return Peek(&outByte, 1); +} + +size_t BufferedTransformation::Peek(byte *outString, size_t peekMax) const +{ + if (AttachedTransformation()) + return AttachedTransformation()->Peek(outString, peekMax); + else + { + ArraySink arraySink(outString, peekMax); + return (size_t)CopyTo(arraySink, peekMax); + } +} + +lword BufferedTransformation::Skip(lword skipMax) +{ + if (AttachedTransformation()) + return AttachedTransformation()->Skip(skipMax); + else + return TransferTo(TheBitBucket(), skipMax); +} + +lword BufferedTransformation::TotalBytesRetrievable() const +{ + if (AttachedTransformation()) + return AttachedTransformation()->TotalBytesRetrievable(); + else + return MaxRetrievable(); +} + +unsigned int BufferedTransformation::NumberOfMessages() const +{ + if (AttachedTransformation()) + return AttachedTransformation()->NumberOfMessages(); + else + return CopyMessagesTo(TheBitBucket()); +} + +bool BufferedTransformation::AnyMessages() const +{ + if (AttachedTransformation()) + return AttachedTransformation()->AnyMessages(); + else + return NumberOfMessages() != 0; +} + +bool BufferedTransformation::GetNextMessage() +{ + if (AttachedTransformation()) + return AttachedTransformation()->GetNextMessage(); + else + { + assert(!AnyMessages()); + return false; + } +} + +unsigned int BufferedTransformation::SkipMessages(unsigned int count) +{ + if (AttachedTransformation()) + return AttachedTransformation()->SkipMessages(count); + else + return TransferMessagesTo(TheBitBucket(), count); +} + +size_t BufferedTransformation::TransferMessagesTo2(BufferedTransformation &target, unsigned int &messageCount, const std::string &channel, bool blocking) +{ + if (AttachedTransformation()) + return AttachedTransformation()->TransferMessagesTo2(target, messageCount, channel, blocking); + else + { + unsigned int maxMessages = messageCount; + for (messageCount=0; messageCount < maxMessages && AnyMessages(); messageCount++) + { + size_t blockedBytes; + lword transferredBytes; + + while (AnyRetrievable()) + { + transferredBytes = LWORD_MAX; + blockedBytes = TransferTo2(target, transferredBytes, channel, blocking); + if (blockedBytes > 0) + return blockedBytes; + } + + if (target.ChannelMessageEnd(channel, GetAutoSignalPropagation(), blocking)) + return 1; + + bool result = GetNextMessage(); + assert(result); + } + return 0; + } +} + +unsigned int BufferedTransformation::CopyMessagesTo(BufferedTransformation &target, unsigned int count, const std::string &channel) const +{ + if (AttachedTransformation()) + return AttachedTransformation()->CopyMessagesTo(target, count, channel); + else + return 0; +} + +void BufferedTransformation::SkipAll() +{ + if (AttachedTransformation()) + AttachedTransformation()->SkipAll(); + else + { + while (SkipMessages()) {} + while (Skip()) {} + } +} + +size_t BufferedTransformation::TransferAllTo2(BufferedTransformation &target, const std::string &channel, bool blocking) +{ + if (AttachedTransformation()) + return AttachedTransformation()->TransferAllTo2(target, channel, blocking); + else + { + assert(!NumberOfMessageSeries()); + + unsigned int messageCount; + do + { + messageCount = UINT_MAX; + size_t blockedBytes = TransferMessagesTo2(target, messageCount, channel, blocking); + if (blockedBytes) + return blockedBytes; + } + while (messageCount != 0); + + lword byteCount; + do + { + byteCount = ULONG_MAX; + size_t blockedBytes = TransferTo2(target, byteCount, channel, blocking); + if (blockedBytes) + return blockedBytes; + } + while (byteCount != 0); + + return 0; + } +} + +void BufferedTransformation::CopyAllTo(BufferedTransformation &target, const std::string &channel) const +{ + if (AttachedTransformation()) + AttachedTransformation()->CopyAllTo(target, channel); + else + { + assert(!NumberOfMessageSeries()); + while (CopyMessagesTo(target, UINT_MAX, channel)) {} + } +} + +void BufferedTransformation::SetRetrievalChannel(const std::string &channel) +{ + if (AttachedTransformation()) + AttachedTransformation()->SetRetrievalChannel(channel); +} + +size_t BufferedTransformation::ChannelPutWord16(const std::string &channel, word16 value, ByteOrder order, bool blocking) +{ + PutWord(false, order, m_buf, value); + return ChannelPut(channel, m_buf, 2, blocking); +} + +size_t BufferedTransformation::ChannelPutWord32(const std::string &channel, word32 value, ByteOrder order, bool blocking) +{ + PutWord(false, order, m_buf, value); + return ChannelPut(channel, m_buf, 4, blocking); +} + +size_t BufferedTransformation::PutWord16(word16 value, ByteOrder order, bool blocking) +{ + return ChannelPutWord16(DEFAULT_CHANNEL, value, order, blocking); +} + +size_t BufferedTransformation::PutWord32(word32 value, ByteOrder order, bool blocking) +{ + return ChannelPutWord32(DEFAULT_CHANNEL, value, order, blocking); +} + +size_t BufferedTransformation::PeekWord16(word16 &value, ByteOrder order) const +{ + byte buf[2] = {0, 0}; + size_t len = Peek(buf, 2); + + if (order) + value = (buf[0] << 8) | buf[1]; + else + value = (buf[1] << 8) | buf[0]; + + return len; +} + +size_t BufferedTransformation::PeekWord32(word32 &value, ByteOrder order) const +{ + byte buf[4] = {0, 0, 0, 0}; + size_t len = Peek(buf, 4); + + if (order) + value = (buf[0] << 24) | (buf[1] << 16) | (buf[2] << 8) | buf [3]; + else + value = (buf[3] << 24) | (buf[2] << 16) | (buf[1] << 8) | buf [0]; + + return len; +} + +size_t BufferedTransformation::GetWord16(word16 &value, ByteOrder order) +{ + return (size_t)Skip(PeekWord16(value, order)); +} + +size_t BufferedTransformation::GetWord32(word32 &value, ByteOrder order) +{ + return (size_t)Skip(PeekWord32(value, order)); +} + +void BufferedTransformation::Attach(BufferedTransformation *newOut) +{ + if (AttachedTransformation() && AttachedTransformation()->Attachable()) + AttachedTransformation()->Attach(newOut); + else + Detach(newOut); +} + +void GeneratableCryptoMaterial::GenerateRandomWithKeySize(RandomNumberGenerator &rng, unsigned int keySize) +{ + GenerateRandom(rng, MakeParameters("KeySize", (int)keySize)); +} + +class PK_DefaultEncryptionFilter : public Unflushable +{ +public: + PK_DefaultEncryptionFilter(RandomNumberGenerator &rng, const PK_Encryptor &encryptor, BufferedTransformation *attachment, const NameValuePairs ¶meters) + : m_rng(rng), m_encryptor(encryptor), m_parameters(parameters) + { + Detach(attachment); + } + + size_t Put2(const byte *inString, size_t length, int messageEnd, bool blocking) + { + FILTER_BEGIN; + m_plaintextQueue.Put(inString, length); + + if (messageEnd) + { + { + size_t plaintextLength; + if (!SafeConvert(m_plaintextQueue.CurrentSize(), plaintextLength)) + throw InvalidArgument("PK_DefaultEncryptionFilter: plaintext too long"); + size_t ciphertextLength = m_encryptor.CiphertextLength(plaintextLength); + + SecByteBlock plaintext(plaintextLength); + m_plaintextQueue.Get(plaintext, plaintextLength); + m_ciphertext.resize(ciphertextLength); + m_encryptor.Encrypt(m_rng, plaintext, plaintextLength, m_ciphertext, m_parameters); + } + + FILTER_OUTPUT(1, m_ciphertext, m_ciphertext.size(), messageEnd); + } + FILTER_END_NO_MESSAGE_END; + } + + RandomNumberGenerator &m_rng; + const PK_Encryptor &m_encryptor; + const NameValuePairs &m_parameters; + ByteQueue m_plaintextQueue; + SecByteBlock m_ciphertext; +}; + +BufferedTransformation * PK_Encryptor::CreateEncryptionFilter(RandomNumberGenerator &rng, BufferedTransformation *attachment, const NameValuePairs ¶meters) const +{ + return new PK_DefaultEncryptionFilter(rng, *this, attachment, parameters); +} + +class PK_DefaultDecryptionFilter : public Unflushable +{ +public: + PK_DefaultDecryptionFilter(RandomNumberGenerator &rng, const PK_Decryptor &decryptor, BufferedTransformation *attachment, const NameValuePairs ¶meters) + : m_rng(rng), m_decryptor(decryptor), m_parameters(parameters) + { + Detach(attachment); + } + + size_t Put2(const byte *inString, size_t length, int messageEnd, bool blocking) + { + FILTER_BEGIN; + m_ciphertextQueue.Put(inString, length); + + if (messageEnd) + { + { + size_t ciphertextLength; + if (!SafeConvert(m_ciphertextQueue.CurrentSize(), ciphertextLength)) + throw InvalidArgument("PK_DefaultDecryptionFilter: ciphertext too long"); + size_t maxPlaintextLength = m_decryptor.MaxPlaintextLength(ciphertextLength); + + SecByteBlock ciphertext(ciphertextLength); + m_ciphertextQueue.Get(ciphertext, ciphertextLength); + m_plaintext.resize(maxPlaintextLength); + m_result = m_decryptor.Decrypt(m_rng, ciphertext, ciphertextLength, m_plaintext, m_parameters); + if (!m_result.isValidCoding) + throw InvalidCiphertext(m_decryptor.AlgorithmName() + ": invalid ciphertext"); + } + + FILTER_OUTPUT(1, m_plaintext, m_result.messageLength, messageEnd); + } + FILTER_END_NO_MESSAGE_END; + } + + RandomNumberGenerator &m_rng; + const PK_Decryptor &m_decryptor; + const NameValuePairs &m_parameters; + ByteQueue m_ciphertextQueue; + SecByteBlock m_plaintext; + DecodingResult m_result; +}; + +BufferedTransformation * PK_Decryptor::CreateDecryptionFilter(RandomNumberGenerator &rng, BufferedTransformation *attachment, const NameValuePairs ¶meters) const +{ + return new PK_DefaultDecryptionFilter(rng, *this, attachment, parameters); +} + +size_t PK_Signer::Sign(RandomNumberGenerator &rng, PK_MessageAccumulator *messageAccumulator, byte *signature) const +{ + std::auto_ptr m(messageAccumulator); + return SignAndRestart(rng, *m, signature, false); +} + +size_t PK_Signer::SignMessage(RandomNumberGenerator &rng, const byte *message, size_t messageLen, byte *signature) const +{ + std::auto_ptr m(NewSignatureAccumulator(rng)); + m->Update(message, messageLen); + return SignAndRestart(rng, *m, signature, false); +} + +size_t PK_Signer::SignMessageWithRecovery(RandomNumberGenerator &rng, const byte *recoverableMessage, size_t recoverableMessageLength, + const byte *nonrecoverableMessage, size_t nonrecoverableMessageLength, byte *signature) const +{ + std::auto_ptr m(NewSignatureAccumulator(rng)); + InputRecoverableMessage(*m, recoverableMessage, recoverableMessageLength); + m->Update(nonrecoverableMessage, nonrecoverableMessageLength); + return SignAndRestart(rng, *m, signature, false); +} + +bool PK_Verifier::Verify(PK_MessageAccumulator *messageAccumulator) const +{ + std::auto_ptr m(messageAccumulator); + return VerifyAndRestart(*m); +} + +bool PK_Verifier::VerifyMessage(const byte *message, size_t messageLen, const byte *signature, size_t signatureLength) const +{ + std::auto_ptr m(NewVerificationAccumulator()); + InputSignature(*m, signature, signatureLength); + m->Update(message, messageLen); + return VerifyAndRestart(*m); +} + +DecodingResult PK_Verifier::Recover(byte *recoveredMessage, PK_MessageAccumulator *messageAccumulator) const +{ + std::auto_ptr m(messageAccumulator); + return RecoverAndRestart(recoveredMessage, *m); +} + +DecodingResult PK_Verifier::RecoverMessage(byte *recoveredMessage, + const byte *nonrecoverableMessage, size_t nonrecoverableMessageLength, + const byte *signature, size_t signatureLength) const +{ + std::auto_ptr m(NewVerificationAccumulator()); + InputSignature(*m, signature, signatureLength); + m->Update(nonrecoverableMessage, nonrecoverableMessageLength); + return RecoverAndRestart(recoveredMessage, *m); +} + +void SimpleKeyAgreementDomain::GenerateKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const +{ + GeneratePrivateKey(rng, privateKey); + GeneratePublicKey(rng, privateKey, publicKey); +} + +void AuthenticatedKeyAgreementDomain::GenerateStaticKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const +{ + GenerateStaticPrivateKey(rng, privateKey); + GenerateStaticPublicKey(rng, privateKey, publicKey); +} + +void AuthenticatedKeyAgreementDomain::GenerateEphemeralKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const +{ + GenerateEphemeralPrivateKey(rng, privateKey); + GenerateEphemeralPublicKey(rng, privateKey, publicKey); +} + +NAMESPACE_END + +#endif -- cgit v1.2.3