1
0
mirror of https://github.com/SoftEtherVPN/SoftEtherVPN.git synced 2025-04-03 18:00:08 +03:00
SoftEtherVPN/src/Cedar/Proto_WireGuard.c
Davide Beatrici dd1eebdbed Cedar: Implement support for WireGuard
Please note that the implementation is not 100% conformant to the protocol whitepaper (https://www.wireguard.com/papers/wireguard.pdf).
More specifically: all peers are expected to send a handshake initiation once the current keypair is about to expire or is expired.
I decided not to do that because our implementation is meant to act only as a server, whereas a true WireGuard peer acts as both a client and a server.
Once the keypair expires, we immediately delete the session.

The cookie mechanism may be implemented in the future.

As for authentication: unfortunately, the existing methods cannot be used because the protocol provides no way to send strings to a peer.
That's because WireGuard doesn't have a concept of "users": it identifies a peer through the public key, which is determined using the source address.
As a solution, this commit adds a special authentication method: once we receive the handshake initiation message and decrypt the peer's public key, we check whether it's in the allowed key list.
If it is, we retrieve the associated Virtual Hub and user; if the hub exists and the user is in it, the authentication is successful.

The allowed key list is stored in the configuration file like this:

declare WireGuardKeyList
{
	declare 96oA7iMvjn7oXiG3ghBDPaSUytT75uXceLV+Fx3XMlM=
	{
		string Hub DEFAULT
		string User user
	}
}
2021-03-01 02:49:59 +01:00

1089 lines
26 KiB
C

#include "CedarPch.h"
#include <blake2.h>
// Returns the descriptor that plugs WireGuard into the generic protocol
// framework. The table has static storage, so the same pointer is returned
// on every call. Entries are positional and must follow PROTO_IMPL's member
// order.
const PROTO_IMPL *WgsGetProtoImpl()
{
	static const PROTO_IMPL wireguard_impl =
	{
		WgsName,              // protocol name
		WgsOptions,           // supported configuration options
		WgsOptionStringValue, // default option values
		WgsInit,              // allocates a WG_SERVER instance
		WgsFree,              // releases a WG_SERVER instance
		WgsIsPacketForMe,     // datagram classification
		NULL,                 // unused slot (presumably the stream/TCP handler; WireGuard is UDP-only)
		WgsProcessDatagrams   // UDP datagram processing
	};

	return &wireguard_impl;
}
// Name under which this protocol is reported to the framework.
const char *WgsName()
{
	static const char protocol_name[] = "WireGuard";

	return protocol_name;
}
// Lists the configuration options understood by the WireGuard server.
// Both entries are strings; WgsInit() decodes them as base64 keys. The table
// ends with a sentinel entry whose name is NULL and type PROTO_OPTION_UNKNOWN.
const PROTO_OPTION *WgsOptions()
{
	static const PROTO_OPTION wg_options[] =
	{
		{ .Name = "PresharedKey", .Type = PROTO_OPTION_STRING, .String = NULL },
		{ .Name = "PrivateKey",   .Type = PROTO_OPTION_STRING, .String = NULL },
		// End-of-table sentinel.
		{ .Name = NULL, .Type = PROTO_OPTION_UNKNOWN }
	};

	return wg_options;
}
// Generates a default value for the named option.
// For "PresharedKey" and "PrivateKey", it creates a fresh random key of
// WG_KEY_SIZE bytes. The key is returned as a base64 string (original
// variant) in memory obtained from Malloc(); the caller presumably takes
// ownership and frees it. Any other name, or NULL, yields NULL.
char *WgsOptionStringValue(const char *name)
{
	unsigned char random_key[WG_KEY_SIZE];
	UINT encoded_size;
	char *encoded;

	if (name == NULL)
	{
		return NULL;
	}

	if (StrCmp(name, "PresharedKey") != 0 && StrCmp(name, "PrivateKey") != 0)
	{
		return NULL;
	}

	encoded_size = sodium_base64_ENCODED_LEN(sizeof(random_key), sodium_base64_VARIANT_ORIGINAL);
	encoded = Malloc(encoded_size);

	Rand(random_key, sizeof(random_key));
	sodium_bin2base64(encoded, encoded_size, random_key, sizeof(random_key), sodium_base64_VARIANT_ORIGINAL);

	// Wipe the raw key bytes from the stack once they are encoded.
	Zero(random_key, sizeof(random_key));

	return encoded;
}
// Allocates and initializes the state of one WireGuard server instance.
//
// Initialization steps:
//  - Decodes the optional base64 "PresharedKey" option. If it is empty, the
//    key stays all-zero, which is what the WireGuard whitepaper uses when no
//    PSK is configured.
//  - Decodes the base64 "PrivateKey" option and derives the static public key.
//  - Precomputes the initial handshake chaining key and hash. These depend
//    only on the protocol constants.
//  - Schedules an interrupt at CreationTime + WG_INITIATION_GIVEUP.
//
// Returns true and stores the new WG_SERVER in *param on success. Returns
// false if an argument is invalid, or if a configured key is not valid base64
// of exactly the expected length. Such keys used to be accepted silently as
// partially decoded or all-zero keys, and could never complete a handshake.
bool WgsInit(void **param, const LIST *options, CEDAR *cedar, INTERRUPT_MANAGER *im, SOCK_EVENT *se, const char *cipher, const char *hostname)
{
	UINT i;
	WG_SERVER *server;
	blake2s_state b2s_state;

	if (param == NULL || options == NULL || cedar == NULL || im == NULL || se == NULL)
	{
		return false;
	}

	Debug("WgsInit(): cipher: %s, hostname: %s\n", cipher, hostname);

	server = ZeroMalloc(sizeof(WG_SERVER));

	for (i = 0; i < LIST_NUM(options); ++i)
	{
		const PROTO_OPTION *option = LIST_DATA(options, i);
		size_t decoded_len = 0;

		if (StrCmp(option->Name, "PresharedKey") == 0)
		{
			if (IsEmptyStr(option->String))
			{
				continue;
			}

			if (sodium_base642bin(server->PresharedKey, sizeof(server->PresharedKey), option->String, StrLen(option->String), NULL, &decoded_len, NULL, sodium_base64_VARIANT_ORIGINAL) != 0 ||
				decoded_len != sizeof(server->PresharedKey))
			{
				Debug("WgsInit(): invalid PresharedKey!\n");
				goto FAIL;
			}
		}
		else if (StrCmp(option->Name, "PrivateKey") == 0)
		{
			if (sodium_base642bin(server->StaticPrivate, sizeof(server->StaticPrivate), option->String, StrLen(option->String), NULL, &decoded_len, NULL, sodium_base64_VARIANT_ORIGINAL) != 0 ||
				decoded_len != sizeof(server->StaticPrivate))
			{
				Debug("WgsInit(): invalid PrivateKey!\n");
				goto FAIL;
			}

			crypto_scalarmult_curve25519_base(server->StaticPublic, server->StaticPrivate);
		}
	}

	server->Cedar = cedar;
	server->SockEvent = se;
	server->InterruptManager = im;

	// Ci = Hash(Construction); Hi = Hash(Ci || Identifier).
	blake2s(server->HandshakeInitChainingKey, sizeof(server->HandshakeInitChainingKey), WG_CONSTRUCTION, StrLen(WG_CONSTRUCTION), NULL, 0);

	blake2s_init(&b2s_state, sizeof(server->HandshakeInitHash));
	blake2s_update(&b2s_state, server->HandshakeInitChainingKey, sizeof(server->HandshakeInitChainingKey));
	blake2s_update(&b2s_state, WG_IDENTIFIER, StrLen(WG_IDENTIFIER));
	blake2s_final(&b2s_state, server->HandshakeInitHash, sizeof(server->HandshakeInitHash));

	server->CreationTime = Tick64();
	// Wake up once the initiation deadline passes, so WgsProcessDatagrams() can give up.
	AddInterrupt(im, server->CreationTime + WG_INITIATION_GIVEUP);

	*param = server;
	return true;

FAIL:
	// A failed decode may have written partial key material: wipe it.
	Zero(server, sizeof(WG_SERVER));
	Free(server);
	return false;
}
void WgsFree(void *param)
{
WG_SERVER *server = param;
WG_SESSION *session;
WG_KEYPAIRS *keypairs;
if (server == NULL)
{
return;
}
session = &server->Session;
keypairs = &session->Keypairs;
FreeIPC(session->IPC);
Zero(keypairs->Current, sizeof(WG_KEYPAIR));
Free(keypairs->Current);
Zero(keypairs->Next, sizeof(WG_KEYPAIR));
Free(keypairs->Next);
Zero(keypairs->Previous, sizeof(WG_KEYPAIR));
Free(keypairs->Previous);
Zero(server, sizeof(WG_SERVER));
Free(server);
}
// Tells the framework whether a datagram looks like WireGuard. Only UDP is
// supported. For UDP, the detected message type itself is returned and used
// as a truth value. This presumably relies on WG_MSG_INVALID being 0 —
// confirm in the header.
bool WgsIsPacketForMe(const PROTO_MODE mode, const void *data, const UINT size)
{
	if (mode == PROTO_MODE_UDP)
	{
		return WgsDetectMessageType(data, size);
	}

	return false;
}
// Processes one batch of incoming UDP datagrams for this WireGuard session
// and appends the datagrams to send back to `out`.
//
//  1. Expiry:
//     - Fail if the current keypair is older than WG_REJECT_AFTER_TIME.
//     - Fail if no keypair was established within WG_INITIATION_GIVEUP.
//     - Drop an expired previous keypair.
//  2. Each incoming datagram:
//     - Handshake initiations are rate-limited and answered.
//     - Transport data is decrypted and forwarded to the IPC.
//     - Cookie messages are ignored.
//     - Anything else ends the session.
//  3. Once an IPC exists:
//     - Encrypt the IPv4 packets it queued for the peer.
//     - Send a keepalive if data was received but nothing has been sent for
//       WG_KEEPALIVE_TIMEOUT.
//
// Returns false when the session must be torn down.
bool WgsProcessDatagrams(void *param, LIST *in, LIST *out)
{
	UINT i;
	WG_SERVER *server = param;
	WG_SESSION *session;
	WG_KEYPAIRS *keypairs;

	if (server == NULL || in == NULL || out == NULL)
	{
		return false;
	}

	server->Now = Tick64();

	session = &server->Session;
	keypairs = &session->Keypairs;

	if (keypairs->Current != NULL)
	{
		const WG_KEYPAIR *current = keypairs->Current;
		if (server->Now - current->CreationTime >= WG_REJECT_AFTER_TIME)
		{
			WgsLog(server, "LW_KEYPAIR_EXPIRED", current->IndexRemote, current->IndexLocal);
			return false;
		}
	}
	else if (server->Now - server->CreationTime >= WG_INITIATION_GIVEUP)
	{
		Debug("WgsProcessDatagrams(): current keypair not present, giving up!\n");
		return false;
	}

	if (keypairs->Previous != NULL)
	{
		WG_KEYPAIR *previous = keypairs->Previous;
		if (server->Now - previous->CreationTime >= WG_REJECT_AFTER_TIME)
		{
			Debug("WgsProcessDatagrams(): deleting keypair: %x -> %x\n", previous->IndexRemote, previous->IndexLocal);
			Zero(previous, sizeof(WG_KEYPAIR));
			Free(previous);
			keypairs->Previous = NULL;
		}
	}

	for (i = 0; i < LIST_NUM(in); ++i)
	{
		const UDPPACKET *packet = LIST_DATA(in, i);
		const UINT size = packet->Size;
		void *data = packet->Data;

		const WG_MSG_TYPE message_type = WgsDetectMessageType(data, size);
		switch (message_type)
		{
		case WG_MSG_HANDSHAKE_INIT:
		{
			WG_KEYPAIR *keypair;
			UDPPACKET *udp_reply;
			WG_HANDSHAKE_REPLY *reply;
			BYTE ephemeral_remote[WG_KEY_SIZE];

			// At most WG_MAX_INITIATIONS_PER_SECOND initiations are accepted per second.
			if (session->LastInitiationReceived + 1000 / WG_MAX_INITIATIONS_PER_SECOND > server->Now)
			{
				WgsLog(server, "LW_FLOOD_ATTACK");
				return false;
			}

			session->LastInitiationReceived = server->Now;

			keypair = WgsProcessHandshakeInit(server, data, ephemeral_remote);
			if (keypair == NULL)
			{
				Debug("WgsProcessDatagrams(): WgsProcessHandshakeInit() failed!\n");
				Zero(ephemeral_remote, sizeof(ephemeral_remote));
				return false;
			}

			reply = WgsCreateHandshakeReply(server, keypair, ephemeral_remote);
			Zero(ephemeral_remote, sizeof(ephemeral_remote));
			if (reply == NULL)
			{
				Debug("WgsProcessDatagrams(): WgsCreateHandshakeReply() failed!\n");
				Zero(keypair, sizeof(WG_KEYPAIR));
				Free(keypair);
				return false;
			}

			// The peer's address may change between handshakes: remember the latest one.
			Copy(&session->IPLocal, &packet->DstIP, sizeof(session->IPLocal));
			Copy(&session->IPRemote, &packet->SrcIP, sizeof(session->IPRemote));
			session->PortLocal = packet->DestPort;
			session->PortRemote = packet->SrcPort;

			udp_reply = NewUdpPacket(&session->IPLocal, session->PortLocal, &session->IPRemote, session->PortRemote, reply, sizeof(WG_HANDSHAKE_REPLY));
			Add(out, udp_reply);

			AddInterrupt(server->InterruptManager, keypair->CreationTime + WG_REJECT_AFTER_TIME);
			break;
		}
		case WG_MSG_HANDSHAKE_COOKIE:
			// TODO: reply to message.
			continue;
		case WG_MSG_TRANSPORT_DATA:
			if (keypairs->Current == NULL)
			{
				continue;
			}

			if (WgsProcessTransportData(server, data, size) == false)
			{
				Debug("WgsProcessDatagrams(): WgsProcessTransportData() failed!\n");
				return false;
			}

			session->LastDataReceived = server->Now;
			break;
		default:
			Debug("WgsProcessDatagrams(): unrecognized packet type %u\n", message_type);
			return false;
		}
	}

	if (session->IPC == NULL)
	{
		return true;
	}

	if (IsIPCConnected(session->IPC) == false)
	{
		WgsLog(server, "LW_HUB_DISCONNECT");
		return false;
	}

	IPCProcessL3Events(session->IPC);

	while (true)
	{
		UDPPACKET *udp;
		UINT final_size = 0;
		WG_TRANSPORT_DATA *data;
		BLOCK *block = IPCRecvIPv4(session->IPC);
		if (block == NULL)
		{
			break;
		}

		data = WgsCreateTransportData(server, block->Buf, block->Size, &final_size);
		FreeBlock(block);

		if (data == NULL)
		{
			continue;
		}

		udp = NewUdpPacket(&session->IPLocal, session->PortLocal, &session->IPRemote, session->PortRemote, data, final_size);
		Add(out, udp);
	}

	if (LIST_NUM(out) > 0)
	{
		session->LastDataSent = server->Now;
	}
	else if (session->LastDataReceived >= session->LastDataSent)
	{
		if (server->Now - session->LastDataSent >= WG_KEEPALIVE_TIMEOUT)
		{
			UINT final_size = 0;
			WG_TRANSPORT_DATA *data = WgsCreateTransportData(server, NULL, 0, &final_size);

			// WgsCreateTransportData() returns NULL, for example when the keypair's
			// message counter is exhausted. In that case send nothing, instead of
			// queueing a datagram without a payload.
			if (data != NULL)
			{
				UDPPACKET *udp = NewUdpPacket(&session->IPLocal, session->PortLocal, &session->IPRemote, session->PortRemote, data, final_size);
				Add(out, udp);

				Debug("WgsProcessDatagrams(): sending keepalive packet\n");

				session->LastDataSent = server->Now;

				// Schedule next keepalive.
				AddInterrupt(server->InterruptManager, server->Now + WG_KEEPALIVE_TIMEOUT);
			}
		}
	}

	return true;
}
// Writes a server log line. The line is prefixed with the session endpoints
// (LW_PREFIX_SESSION), followed by the localized message `name` formatted
// with the variadic arguments.
void WgsLog(const WG_SERVER *server, const char *name, ...)
{
	wchar_t message[MAX_SIZE * 2];
	const WG_SESSION *session;
	UINT current_len;
	va_list args;

	if (server == NULL)
	{
		return;
	}

	session = &server->Session;

	UniFormat(message, sizeof(message), _UU("LW_PREFIX_SESSION"), &session->IPRemote, session->PortRemote, &session->IPLocal, session->PortLocal);

	// UniFormat*() take the buffer size in bytes (see the call above), while
	// UniStrLen() counts characters. Convert the prefix length to bytes
	// before computing the space left; otherwise the size is overstated and
	// the formatted text can overflow `message`.
	current_len = UniStrLen(message);

	va_start(args, name);
	UniFormatArgs(message + current_len, sizeof(message) - current_len * sizeof(wchar_t), _UU(name), args);
	va_end(args);

	WriteServerLog(server->Cedar, message);
}
WG_MSG_TYPE WgsDetectMessageType(const void *data, const UINT size)
{
const WG_COMMON *packet = data;
if (packet == NULL || size < sizeof(WG_COMMON))
{
return WG_MSG_INVALID;
}
switch (packet->Header.Type)
{
case WG_MSG_HANDSHAKE_INIT:
if (size != sizeof(WG_HANDSHAKE_INIT))
{
return WG_MSG_INVALID;
}
break;
case WG_MSG_HANDSHAKE_REPLY:
if (size != sizeof(WG_HANDSHAKE_REPLY))
{
return WG_MSG_INVALID;
}
break;
case WG_MSG_HANDSHAKE_COOKIE:
if (size != sizeof(WG_COOKIE_REPLY))
{
return WG_MSG_INVALID;
}
break;
case WG_MSG_TRANSPORT_DATA:
if (size < sizeof(WG_TRANSPORT_DATA) + WG_AEAD_SIZE(0))
{
return WG_MSG_INVALID;
}
break;
default:
return WG_MSG_INVALID;
}
if (IsZero(packet->Header.Reserved, sizeof(packet->Header.Reserved)) == false)
{
return WG_MSG_INVALID;
}
if (packet->Index == 0)
{
return WG_MSG_INVALID;
}
return packet->Header.Type;
}
UINT WgsMSS(const WG_SESSION *session)
{
UINT ret = MTU_FOR_PPPOE;
if (session == NULL)
{
return 0;
}
// IPv4 / IPv6
if (IsIP4(&session->IPRemote))
{
ret -= 20;
}
else
{
ret -= 40;
}
// UDP
ret -= 8;
// WireGuard packet
ret -= sizeof(WG_TRANSPORT_DATA);
// Inner IPv4
ret -= 20;
// Inner TCP
ret -= 20;
return ret;
}
// Creates the layer-3 IPC connection for the session, using NewIPCByParam().
// The peer's static public key is passed base64-encoded as WgKey.
// Presumably this is what authenticates the peer against the allowed key
// list — confirm in NewIPCByParam().
// Returns NULL, after a debug message, if the server is NULL or the IPC cannot be created.
IPC *WgsIPCNew(WG_SERVER *server)
{
	IPC_PARAM ipc_param;
	WG_SESSION *s;
	IPC *result;
	UINT error_code;

	if (server == NULL)
	{
		return NULL;
	}

	s = &server->Session;
	Zero(&ipc_param, sizeof(ipc_param));

	// Identification strings.
	StrCpy(ipc_param.ClientName, sizeof(ipc_param.ClientName), WgsName());
	StrCpy(ipc_param.Postfix, sizeof(ipc_param.Postfix), WG_IPC_POSTFIX);
	StrCpy(ipc_param.CryptName, sizeof(ipc_param.CryptName), WG_CIPHER);

	// Peer identity.
	sodium_bin2base64(ipc_param.WgKey, sizeof(ipc_param.WgKey), s->StaticRemote, sizeof(s->StaticRemote), sodium_base64_VARIANT_ORIGINAL);

	// Endpoints: we are the server, the remote peer is the client.
	Copy(&ipc_param.ServerIp, &s->IPLocal, sizeof(ipc_param.ServerIp));
	ipc_param.ServerPort = s->PortLocal;
	Copy(&ipc_param.ClientIp, &s->IPRemote, sizeof(ipc_param.ClientIp));
	ipc_param.ClientPort = s->PortRemote;

	ipc_param.Layer = IPC_LAYER_3;
	ipc_param.Mss = WgsMSS(s);

	result = NewIPCByParam(server->Cedar, &ipc_param, &error_code);
	if (result == NULL)
	{
		Debug("WgsIPCNew(): NewIPCByParam() failed with error %u!\n", error_code);
	}

	return result;
}
// Processes a handshake initiation (responder side of the WireGuard
// handshake).
//
// Starting from the precomputed initial hash and chaining key, it:
//  1. Mixes our static public key and the initiator's ephemeral key.
//  2. Decrypts the initiator's static public key. If the session already has
//     a peer key, the decrypted key must match it.
//  3. Mixes the static-static DH.
//  4. Decrypts the timestamp, which must be strictly greater than the last
//     accepted one (replay protection).
//
// On success, the following are committed to the session:
//  - hash and chaining key,
//  - peer static key,
//  - timestamp,
//  - static-static DH.
// A new keypair in state WG_KEYPAIR_INITIATED is returned. The caller owns
// it until it is handed to WgsCreateHandshakeReply().
// The initiator's ephemeral public key is written to ephemeral_remote
// (WG_KEY_SIZE bytes).
// Returns NULL on failure; the session is not modified in that case.
//
// NOTE(review): the initiation's mac1/mac2 fields are not verified here.
WG_KEYPAIR *WgsProcessHandshakeInit(WG_SERVER *server, const WG_HANDSHAKE_INIT *init, BYTE *ephemeral_remote)
{
	WG_SESSION *session;
	WG_KEYPAIR *keypair = NULL;
	BYTE hash[WG_HASH_SIZE];
	BYTE key[WG_KEY_SIZE];
	BYTE chaining_key[WG_HASH_SIZE];
	BYTE timestamp[WG_TIMESTAMP_SIZE];
	BYTE static_remote[WG_KEY_SIZE];
	BYTE static_static[WG_KEY_SIZE];
	bool known_peer;

	if (server == NULL || init == NULL || ephemeral_remote == NULL)
	{
		return NULL;
	}

	session = &server->Session;

	Copy(hash, server->HandshakeInitHash, sizeof(server->HandshakeInitHash));
	Copy(chaining_key, server->HandshakeInitChainingKey, sizeof(server->HandshakeInitChainingKey));
	WgsMixHash(hash, server->StaticPublic, sizeof(server->StaticPublic));

	WgsEphemeral(ephemeral_remote, init->UnencryptedEphemeral, chaining_key, hash);

	if (WgsMixDh(chaining_key, key, server->StaticPrivate, ephemeral_remote) == false)
	{
		Debug("WgsProcessHandshakeInit(): WgsMixDh() failed!\n");
		goto FINAL;
	}

	if (WgsDecryptWithHash(static_remote, init->EncryptedStatic, sizeof(init->EncryptedStatic), hash, key) == false)
	{
		Debug("WgsProcessHandshakeInit(): WgsDecryptWithHash() failed to decrypt the static key!\n");
		goto FINAL;
	}

	known_peer = (IsZero(session->StaticRemote, sizeof(session->StaticRemote)) == false);
	if (known_peer && Cmp(static_remote, session->StaticRemote, sizeof(static_remote)) != 0)
	{
		Debug("WgsProcessHandshakeInit(): static remote key doesn't match!\n");
		goto FINAL;
	}

	// The cached static-static DH is trusted only after it has been committed,
	// together with the peer's static key, by a fully successful handshake.
	// Anyone can encrypt an arbitrary static key to our public key. Caching
	// the DH before authentication would let such a bogus initiation poison
	// the cache and lock the real peer out.
	if (known_peer && IsZero(session->PrecomputedStaticStatic, sizeof(session->PrecomputedStaticStatic)) == false)
	{
		Copy(static_static, session->PrecomputedStaticStatic, sizeof(static_static));
	}
	else
	{
		Debug("WgsProcessHandshakeInit(): precomputing static static...\n");
		if (crypto_scalarmult_curve25519(static_static, server->StaticPrivate, static_remote) != 0)
		{
			Debug("WgsProcessHandshakeInit(): crypto_scalarmult_curve25519() failed!\n");
			goto FINAL;
		}
	}

	WgsHKDF(chaining_key, key, NULL, static_static, sizeof(static_static), chaining_key);

	if (WgsDecryptWithHash(timestamp, init->EncryptedTimestamp, sizeof(init->EncryptedTimestamp), hash, key) == false)
	{
		Debug("WgsProcessHandshakeInit(): WgsDecrypt() failed to decrypt the timestamp!\n");
		goto FINAL;
	}

	// Replay protection. Timestamps are compared bytewise; WireGuard
	// timestamps are big-endian TAI64N. The new timestamp must be strictly
	// newer than the last accepted one. (This comparison used to be
	// misparenthesised, so it compared zero bytes and never fired.)
	if (Cmp(timestamp, session->LastTimestamp, sizeof(timestamp)) <= 0)
	{
		WgsLog(server, "LW_REPLAY_ATTACK");
		goto FINAL;
	}

	Copy(session->LastTimestamp, timestamp, sizeof(session->LastTimestamp));
	Copy(session->Hash, hash, sizeof(session->Hash));
	Copy(session->ChainingKey, chaining_key, sizeof(session->ChainingKey));
	Copy(session->StaticRemote, static_remote, sizeof(session->StaticRemote));
	Copy(session->PrecomputedStaticStatic, static_static, sizeof(static_static));

	keypair = ZeroMalloc(sizeof(WG_KEYPAIR));
	keypair->State = WG_KEYPAIR_INITIATED;
	keypair->CreationTime = server->Now;
	keypair->IndexLocal = Rand32();
	keypair->IndexRemote = init->SenderIndex;
FINAL:
	Zero(key, sizeof(key));
	Zero(hash, sizeof(hash));
	Zero(chaining_key, sizeof(chaining_key));
	Zero(static_remote, sizeof(static_remote));
	Zero(static_static, sizeof(static_static));

	return keypair;
}
// Builds the handshake response for a keypair produced by
// WgsProcessHandshakeInit() (state WG_KEYPAIR_INITIATED).
//
// Steps:
//  1. Generate our ephemeral key and mix it in.
//  2. Mix the ephemeral-ephemeral DH, the ephemeral-static DH and the
//     preshared key.
//  3. Encrypt the empty payload, then compute mac1 over the message.
//  4. Derive the transport keys into the keypair (which becomes
//     WG_KEYPAIR_CONFIRMED).
//
// Keypair installation:
//  - It becomes Current if there is none.
//  - Otherwise it becomes Next, replacing any previous Next. It is promoted
//    when transport data using its index arrives.
//
// On success, returns the reply; ownership of the keypair passes to the
// session. On failure, returns NULL and the caller keeps ownership of the
// keypair.
WG_HANDSHAKE_REPLY *WgsCreateHandshakeReply(WG_SERVER *server, WG_KEYPAIR *keypair, const BYTE *ephemeral_remote)
{
	bool ok = false;
	WG_SESSION *session;
	WG_HANDSHAKE_REPLY *ret;
	BYTE hash[WG_HASH_SIZE];
	BYTE key[WG_KEY_SIZE];
	BYTE ephemeral[WG_KEY_SIZE];
	blake2s_state blake;

	if (server == NULL || keypair == NULL || ephemeral_remote == NULL)
	{
		return NULL;
	}

	if (keypair->State != WG_KEYPAIR_INITIATED)
	{
		Debug("WgsCreateHandshakeReply(): unexpected keypair state %u!\n", keypair->State);
		return NULL;
	}

	session = &server->Session;

	ret = ZeroMalloc(sizeof(WG_HANDSHAKE_REPLY));
	ret->Header.Type = WG_MSG_HANDSHAKE_REPLY;
	ret->SenderIndex = keypair->IndexLocal;
	ret->ReceiverIndex = keypair->IndexRemote;

	// Our ephemeral public key goes straight into the message; the private half stays in `ephemeral`.
	crypto_box_curve25519xsalsa20poly1305_keypair(ret->UnencryptedEphemeral, ephemeral);
	WgsEphemeral(ret->UnencryptedEphemeral, ret->UnencryptedEphemeral, session->ChainingKey, session->Hash);

	if (WgsMixDh(session->ChainingKey, NULL, ephemeral, ephemeral_remote) == false)
	{
		Debug("WgsCreateHandshakeReply(): WgsMixDh() failed to mix ephemeral public!\n");
		goto FINAL;
	}

	if (WgsMixDh(session->ChainingKey, NULL, ephemeral, session->StaticRemote) == false)
	{
		Debug("WgsCreateHandshakeReply(): WgsMixDh() failed to mix static public!\n");
		goto FINAL;
	}

	WgsHKDF(session->ChainingKey, hash, key, server->PresharedKey, sizeof(server->PresharedKey), session->ChainingKey);
	WgsMixHash(session->Hash, hash, sizeof(hash));

	// WgsEncryptWithHash() already mixes the ciphertext into session->Hash.
	// It used to be mixed a second time here, which diverged from the
	// protocol transcript (and from the decrypt path).
	if (WgsEncryptWithHash(ret->EncryptedNothing, NULL, 0, session->Hash, key) == false)
	{
		Debug("WgsCreateHandshakeReply(): WgsEncryptWithHash() failed!\n");
		goto FINAL;
	}

	// mac1 = MAC(Hash(LABEL_MAC1 || peer static public), message up to the MACs).
	blake2s_init(&blake, sizeof(key));
	blake2s_update(&blake, WG_LABEL_MAC1, StrLen(WG_LABEL_MAC1));
	blake2s_update(&blake, session->StaticRemote, sizeof(session->StaticRemote));
	blake2s_final(&blake, key, sizeof(key));

	blake2s(ret->Macs.Mac1, sizeof(ret->Macs.Mac1), ret, sizeof(WG_HANDSHAKE_REPLY) - sizeof(WG_MACS), key, sizeof(key));

	ok = true;
FINAL:
	Zero(key, sizeof(key));
	Zero(hash, sizeof(hash));
	Zero(ephemeral, sizeof(ephemeral));

	if (ok)
	{
		WG_KEYPAIRS *keypairs = &session->Keypairs;

		// As responder: the first derived key receives, the second sends.
		WgsHKDF(keypair->KeyRemote, keypair->KeyLocal, NULL, NULL, 0, session->ChainingKey);
		keypair->State = WG_KEYPAIR_CONFIRMED;

		Debug("WgsCreateHandshakeReply(): new keypair available: %x -> %x\n", keypair->IndexRemote, keypair->IndexLocal);

		if (keypairs->Next != NULL)
		{
			WG_KEYPAIR *next = keypairs->Next;
			Debug("WgsCreateHandshakeReply(): deleting keypair: %x -> %x\n", next->IndexRemote, next->IndexLocal);
			Zero(next, sizeof(WG_KEYPAIR));
			Free(next);
		}

		if (keypairs->Current == NULL)
		{
			Debug("WgsCreateHandshakeReply(): switched to keypair: %x -> %x\n", keypair->IndexRemote, keypair->IndexLocal);
			keypairs->Current = keypair;
			keypairs->Next = NULL;
			return ret;
		}

		keypairs->Next = keypair;
		return ret;
	}

	Zero(ret, sizeof(WG_HANDSHAKE_REPLY));
	Free(ret);
	return NULL;
}
// Authenticates, decrypts and forwards one transport data message.
//
// Keypair selection:
//  - The message's ReceiverIndex picks the current, next or previous keypair.
//  - An unknown index is fatal.
//
// Processing:
//  - Replay-check, decrypt in place, check the counter limit, then update the
//    replay window.
//  - Only after a message has been authenticated with the Next keypair is
//    that keypair promoted to Current. Previously the rotation ran before
//    decryption, so a forged packet carrying Next's (cleartext) index could
//    rotate the keys.
//
// Forwarding:
//  - A non-empty payload is sent to the IPC.
//  - The IPC is created on first use, with the inner packet's source address
//    as the client IPv4 address.
//  - An empty payload (keepalive) is accepted and dropped.
//
// Returns false when the session must be torn down.
bool WgsProcessTransportData(WG_SERVER *server, WG_TRANSPORT_DATA *data, const UINT size)
{
	UINT written;
	UINT encrypted_size;
	WG_KEYPAIR *keypair;
	WG_KEYPAIRS *keypairs;
	WG_SESSION *session;

	if (server == NULL || data == NULL || size < sizeof(WG_TRANSPORT_DATA))
	{
		return false;
	}

	encrypted_size = size - sizeof(WG_TRANSPORT_DATA);
	if (encrypted_size < WG_TAG_SIZE)
	{
		return false;
	}

	session = &server->Session;
	keypairs = &session->Keypairs;

	keypair = keypairs->Current;
	if (keypair == NULL)
	{
		return false;
	}

	if (data->ReceiverIndex != keypair->IndexLocal)
	{
		if (keypairs->Next != NULL && data->ReceiverIndex == keypairs->Next->IndexLocal)
		{
			keypair = keypairs->Next;
		}
		else if (keypairs->Previous != NULL && data->ReceiverIndex == keypairs->Previous->IndexLocal)
		{
			keypair = keypairs->Previous;
		}
		else
		{
			WgsLog(server, "LW_KEYPAIR_UNKNOWN");
			return false;
		}
	}

	if (WgsIsInReplayWindow(keypair, data->Counter))
	{
		WgsLog(server, "LW_REPLAY_ATTACK");
		return false;
	}

	written = WgsDecryptData(keypair->KeyRemote, data->Counter, data->EncapsulatedPacket, data->EncapsulatedPacket, encrypted_size);
	if (written == INFINITE)
	{
		WgsLog(server, "LW_DECRYPT_FAIL");
		return false;
	}

	if (data->Counter > WG_REJECT_AFTER_MESSAGES)
	{
		WgsLog(server, "LW_KEYPAIR_EXPIRED", keypair->IndexRemote, keypair->IndexLocal);
		return false;
	}

	WgsUpdateReplayWindow(keypair, data->Counter);

	// The message authenticated with the pending keypair: the peer has
	// switched to it, so rotate current -> previous and next -> current.
	if (keypair == keypairs->Next)
	{
		WG_KEYPAIR *previous = keypairs->Previous;
		if (previous != NULL)
		{
			Debug("WgsProcessTransportData(): deleting keypair: %x -> %x\n", previous->IndexRemote, previous->IndexLocal);
			Zero(previous, sizeof(WG_KEYPAIR));
			Free(previous);
		}

		keypairs->Previous = keypairs->Current;
		keypairs->Current = keypair;
		keypairs->Next = NULL;

		Debug("WgsProcessTransportData(): switched to keypair: %x -> %x\n", keypair->IndexRemote, keypair->IndexLocal);
	}

	if (written == 0)
	{
		// Keepalive: nothing to forward.
		return true;
	}

	if (session->IPC == NULL)
	{
		IP ip;
		PKT pkt;
		IPC *ipc;

		// Parse first, so that a malformed packet cannot leak a freshly created IPC.
		Zero(&pkt, sizeof(pkt));
		if (ParsePacketIPv4(&pkt, data->EncapsulatedPacket, written) == false)
		{
			Debug("WgsProcessTransportData(): ParsePacketIPv4() failed!\n");
			return false;
		}

		ipc = WgsIPCNew(server);
		if (ipc == NULL)
		{
			Debug("WgsProcessTransportData(): WgsIPCNew() returned NULL!\n");
			return false;
		}

		UINTToIP(&ip, pkt.L3.IPv4Header->SrcIP);
		IPCSetIPv4Parameters(ipc, &ip, &ipc->SubnetMask, &ipc->DefaultGateway, NULL);

		IPCSetSockEventWhenRecvL2Packet(ipc, server->SockEvent);
		IPC_PROTO_SET_STATUS(ipc, IPv4State, IPC_PROTO_STATUS_OPENED);

		session->IPC = ipc;
	}

	IPCSendIPv4(session->IPC, data->EncapsulatedPacket, written);

	return true;
}
// Encrypts `size` bytes of `data` into a new transport data message, using
// the current keypair and its next local counter.
// The plaintext is zero-padded to a multiple of WG_BLOCK_SIZE before
// encryption. A NULL/0 payload produces a keepalive.
// The total message size is stored in *final_size.
// Returns a ZeroMalloc()'d message, or NULL in these cases:
//  - invalid arguments,
//  - no current keypair,
//  - the message counter is exhausted,
//  - encryption failure.
// The counter is only incremented on success.
WG_TRANSPORT_DATA *WgsCreateTransportData(WG_SERVER *server, const void *data, const UINT size, UINT *final_size)
{
	UINT pad_size;
	UINT encrypted_size;
	WG_KEYPAIR *keypair;
	WG_TRANSPORT_DATA *ret;

	if (server == NULL || (data == NULL && size > 0) || final_size == NULL)
	{
		return NULL;
	}

	keypair = server->Session.Keypairs.Current;
	if (keypair == NULL)
	{
		Debug("WgsCreateTransportData(): no keypair!\n");
		return NULL;
	}

	if (keypair->CounterLocal > WG_REJECT_AFTER_MESSAGES)
	{
		WgsLog(server, "LW_KEYPAIR_EXPIRED", keypair->IndexRemote, keypair->IndexLocal);
		// This function returns a pointer: report failure with NULL, not `false`.
		return NULL;
	}

	pad_size = (WG_BLOCK_SIZE - (size % WG_BLOCK_SIZE)) % WG_BLOCK_SIZE;
	encrypted_size = WG_AEAD_SIZE(size + pad_size);
	*final_size = sizeof(WG_TRANSPORT_DATA) + encrypted_size;

	ret = ZeroMalloc(*final_size);
	ret->Header.Type = WG_MSG_TRANSPORT_DATA;
	ret->ReceiverIndex = keypair->IndexRemote;
	ret->Counter = keypair->CounterLocal;

	Copy(ret->EncapsulatedPacket, data, size);

	if (WgsEncryptData(keypair->KeyLocal, ret->Counter, ret->EncapsulatedPacket, ret->EncapsulatedPacket, size + pad_size) != encrypted_size)
	{
		Debug("WgsCreateTransportData(): WgsEncryptData() didn't write the expected number of bytes!\n");
		// The buffer may still hold plaintext: wipe it before releasing.
		Zero(ret, *final_size);
		Free(ret);
		return NULL;
	}

	++keypair->CounterLocal;

	return ret;
}
// RFC 6479: ipsec_check_replay_window()
bool WgsIsInReplayWindow(const WG_KEYPAIR *keypair, const UINT64 counter)
{
int bit_location;
int index;
if (keypair == NULL || counter == 0)
{
return false;
}
if (counter > keypair->CounterRemote)
{
return false;
}
if (counter + sizeof(keypair->ReplayWindow) < keypair->CounterRemote)
{
return false;
}
bit_location = counter & WG_REPLAY_BITMAP_LOC_MASK;
index = counter >> WG_REPLAY_REDUNDANT_BIT_SHIFTS & WG_REPLAY_BITMAP_INDEX_MASK;
if (keypair->ReplayWindow[index] & (1 << bit_location))
{
return true;
}
return false;
}
// RFC 6479: ipsec_update_replay_window().
// Records `counter` as received. It should be called only after the message
// has been authenticated and has passed WgsIsInReplayWindow().
// If the counter is the newest seen, the bitmap slides forward: words
// between the old and new positions are cleared, and CounterRemote is updated.
// Counters too old for the window are ignored.
// Differences from the original code:
//  - Counter 0 is recorded, since WireGuard counters start at 0.
//  - Index arithmetic stays in 64 bits instead of truncating to int.
//  - Bit shifts are done in 64 bits.
void WgsUpdateReplayWindow(WG_KEYPAIR *keypair, const UINT64 counter)
{
	UINT64 index;
	UINT64 bit_location;

	if (keypair == NULL)
	{
		return;
	}

	// Too old to be tracked: nothing to record. Written so that it cannot overflow.
	if (counter < keypair->CounterRemote && keypair->CounterRemote - counter > sizeof(keypair->ReplayWindow))
	{
		return;
	}

	index = counter >> WG_REPLAY_REDUNDANT_BIT_SHIFTS;

	if (counter > keypair->CounterRemote)
	{
		const UINT64 index_cur = keypair->CounterRemote >> WG_REPLAY_REDUNDANT_BIT_SHIFTS;
		UINT64 diff = index - index_cur;
		UINT64 id;

		// Clearing more than the whole bitmap is pointless.
		if (diff > WG_REPLAY_BITMAP_SIZE)
		{
			diff = WG_REPLAY_BITMAP_SIZE;
		}

		for (id = 0; id < diff; ++id)
		{
			keypair->ReplayWindow[(id + index_cur + 1) & WG_REPLAY_BITMAP_INDEX_MASK] = 0;
		}

		keypair->CounterRemote = counter;
	}

	index &= WG_REPLAY_BITMAP_INDEX_MASK;
	bit_location = counter & WG_REPLAY_BITMAP_LOC_MASK;

	// Setting an already-set bit is harmless, so no separate check is needed.
	keypair->ReplayWindow[index] |= (UINT64)1 << bit_location;
}
UINT WgsEncryptData(void *key, const UINT64 counter, void *dst, const void *src, const UINT src_size)
{
unsigned long long written;
BYTE iv[WG_IV_SIZE];
if (key == NULL || dst == NULL || (src == NULL && src_size > 0))
{
return INFINITE;
}
Zero(iv, sizeof(iv) - sizeof(counter));
Copy(iv + sizeof(iv) - sizeof(counter), &counter, sizeof(counter));
crypto_aead_chacha20poly1305_ietf_encrypt(dst, &written, src, src_size, NULL, 0, NULL, iv, key);
return written;
}
UINT WgsDecryptData(void *key, const UINT64 counter, void *dst, const void *src, const UINT src_size)
{
unsigned long long written;
BYTE iv[WG_IV_SIZE];
if (key == NULL || src == NULL || src_size == 0)
{
return INFINITE;
}
Zero(iv, sizeof(iv) - sizeof(counter));
Copy(iv + sizeof(iv) - sizeof(counter), &counter, sizeof(counter));
if (crypto_aead_chacha20poly1305_ietf_decrypt(dst, &written, NULL, src, src_size, NULL, 0, iv, key) != 0)
{
return INFINITE;
}
return written;
}
// Handshake AEAD encryption: ChaCha20-Poly1305 IETF with an all-zero nonce.
// `hash` (WG_HASH_SIZE bytes) is used as additional authenticated data.
// Afterwards the WG_AEAD_SIZE(src_size) bytes of ciphertext are mixed into
// `hash`.
// Returns false if the arguments are invalid or nothing was written.
bool WgsEncryptWithHash(void *dst, const void *src, const UINT src_size, BYTE *hash, const BYTE *key)
{
	BYTE nonce[WG_IV_SIZE];
	unsigned long long cipher_len;

	if (dst == NULL || hash == NULL || key == NULL || (src == NULL && src_size > 0))
	{
		return false;
	}

	Zero(nonce, sizeof(nonce));
	crypto_aead_chacha20poly1305_ietf_encrypt(dst, &cipher_len, src, src_size, hash, WG_HASH_SIZE, NULL, nonce, key);

	// The ciphertext, including its tag, becomes part of the transcript hash.
	WgsMixHash(hash, dst, WG_AEAD_SIZE(src_size));

	return (cipher_len > 0);
}
// Handshake AEAD decryption: ChaCha20-Poly1305 IETF with an all-zero nonce.
// `hash` (WG_HASH_SIZE bytes) is used as additional authenticated data.
// On success, the ciphertext `src` is mixed into `hash`; on failure, `hash`
// is left unchanged.
// Returns false if the arguments are invalid, authentication fails, or the
// plaintext is empty.
bool WgsDecryptWithHash(void *dst, const void *src, const UINT src_size, BYTE *hash, const BYTE *key)
{
	BYTE nonce[WG_IV_SIZE];
	unsigned long long plain_len;

	if (hash == NULL || key == NULL || (src == NULL && src_size > 0))
	{
		return false;
	}

	Zero(nonce, sizeof(nonce));

	if (crypto_aead_chacha20poly1305_ietf_decrypt(dst, &plain_len, NULL, src, src_size, hash, WG_HASH_SIZE, nonce, key) != 0)
	{
		return false;
	}

	WgsMixHash(hash, src, src_size);

	return (plain_len > 0);
}
// Handshake step for an ephemeral public key (WG_KEY_SIZE bytes):
//   ephemeral_dst = ephemeral_src
//   hash          = Hash(hash || ephemeral_src)
//   chaining_key  = KDF1(chaining_key, ephemeral_src)
// Changes from the original:
//  - The key is hashed using its own size, WG_KEY_SIZE (previously
//    WG_HASH_SIZE).
//  - ephemeral_dst may be the same buffer as ephemeral_src. In that case no
//    copy is made, since memcpy on identical buffers is undefined.
void WgsEphemeral(BYTE *ephemeral_dst, const BYTE *ephemeral_src, BYTE *chaining_key, BYTE *hash)
{
	if (ephemeral_dst != ephemeral_src)
	{
		Copy(ephemeral_dst, ephemeral_src, WG_KEY_SIZE);
	}

	WgsMixHash(hash, ephemeral_src, WG_KEY_SIZE);
	WgsHKDF(chaining_key, NULL, NULL, ephemeral_src, WG_KEY_SIZE, chaining_key);
}
// HKDF as used by the WireGuard handshake, built on the MD helpers with
// "BLAKE2s256". SetMdKey() presumably turns MdProcess() into HMAC-BLAKE2s,
// as WireGuard specifies — confirm against the MD implementation.
//   secret = MAC(chaining_key, data)
//   dst_1  = MAC(secret, 0x01)
//   dst_2  = MAC(secret, dst_1 || 0x02)
//   dst_3  = MAC(secret, dst_2 || 0x03)
// Each output is WG_KEY_SIZE bytes.
// Outputs are produced in order; generation stops at the first NULL
// destination.
// chaining_key is passed to SetMdKey() before dst_1 is written, so callers
// pass the same buffer as both chaining_key and dst_1. This assumes SetMdKey
// copies the key.
void WgsHKDF(BYTE *dst_1, BYTE *dst_2, BYTE *dst_3, const BYTE *data, const UINT data_size, const BYTE *chaining_key)
{
// output holds the previous MAC (WG_HASH_SIZE bytes) plus one counter byte.
BYTE output[WG_HASH_SIZE + 1];
BYTE secret[WG_HASH_SIZE];
// NOTE(review): NewMd()'s result is not checked for NULL.
MD *md = NewMd("BLAKE2s256");
SetMdKey(md, chaining_key, WG_HASH_SIZE);
// Extract entropy from data into secret.
MdProcess(md, secret, data, data_size);
if (dst_1 == NULL)
{
goto FINAL;
}
SetMdKey(md, secret, sizeof(secret));
// Expand first key
// MAC over the single byte 0x01, written in place into output.
output[0] = 1;
MdProcess(md, output, output, 1);
Copy(dst_1, output, WG_KEY_SIZE);
if (dst_2 == NULL)
{
goto FINAL;
}
// Expand second key
// output[0..WG_HASH_SIZE) still holds the previous MAC; append 0x02 and MAC in place.
output[sizeof(output) - 1] = 2;
MdProcess(md, output, output, sizeof(output));
Copy(dst_2, output, WG_KEY_SIZE);
if (dst_3 == NULL)
{
goto FINAL;
}
// Expand third key
// Same chaining as above, with counter byte 0x03.
output[sizeof(output) - 1] = 3;
MdProcess(md, output, output, sizeof(output));
Copy(dst_3, output, WG_KEY_SIZE);
FINAL:
// Release the MD context and wipe intermediate key material.
FreeMd(md);
Zero(secret, sizeof(secret));
Zero(output, sizeof(output));
}
// In-place transcript update: dst = BLAKE2s(dst || src), with a
// WG_HASH_SIZE-byte digest. dst must hold WG_HASH_SIZE bytes.
// If the arguments are invalid (NULL dst, or NULL src with a non-zero size),
// dst is left untouched.
void WgsMixHash(void *dst, const void *src, const UINT size)
{
	blake2s_state state;

	if (dst != NULL && (src != NULL || size == 0))
	{
		blake2s_init(&state, WG_HASH_SIZE);
		blake2s_update(&state, dst, WG_HASH_SIZE);
		blake2s_update(&state, src, size);
		blake2s_final(&state, dst, WG_HASH_SIZE);
	}
}
// Computes the X25519 shared secret of priv/pub and mixes it into the
// handshake state:
//   (chaining_key, key) = HKDF(chaining_key, shared)
// `key` may be NULL when only the chaining key is wanted.
// Returns false if the arguments are invalid or crypto_scalarmult_curve25519()
// reports failure.
bool WgsMixDh(BYTE *chaining_key, BYTE *key, const BYTE *priv, const BYTE *pub)
{
	bool ok = false;
	BYTE shared[WG_HASH_SIZE];

	if (chaining_key != NULL && priv != NULL && pub != NULL)
	{
		if (crypto_scalarmult_curve25519(shared, priv, pub) == 0)
		{
			WgsHKDF(chaining_key, key, NULL, shared, sizeof(shared), chaining_key);
			// The shared secret is no longer needed once it is mixed in.
			Zero(shared, sizeof(shared));
			ok = true;
		}
		else
		{
			Debug("WgsMixDh(): crypto_scalarmult_curve25519() failed!\n");
		}
	}

	return ok;
}