Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

from contextlib import contextmanager
from logging import DEBUG, getLogger
import socket

from . import AuthenticatedMessage
from .enums import ServerState
from .utils import ascii_bin
from .sessions import ClientSession, ServerSession

 

 

_logger = getLogger('nuts.channel') 

 

 

class AuthChannel(object):
    """ Generic, transport-agnostic authenticated channel.

    The transport is supplied by subclasses, which must implement
    ``send_data(data, address)`` and ``read_data()`` (returning a
    ``(data, sender)`` pair). They may also override ``tear_down`` to release
    transport resources, and provide ``listen`` to accept incoming sessions.
    """

    #: If the underlying transport layer has a MTU, child classes should set
    #: this to the same MTU
    mtu = None

    #: MACs supported by this satellite/server. Used in the SA negotiation,
    #: should be ordered by preference (strength), strongest first.
    supported_macs = [
        'sha3_512',
        'sha3_384',
        'sha3_256',
        # HMAC-SHA256 is stronger than HMAC-SHA1, so it must be preferred.
        'hmac-sha256',
        'hmac-sha1',
    ]

    def __init__(self, path_to_keyfile, *args, **kwargs):
        """ Create a new auth channel context to keep around.

        :param path_to_keyfile: Path of the file holding the shared master key.
            The file is read as bytes and surrounding whitespace is stripped.
            The path is remembered so that a new key can be persisted by
            assigning to ``shared_key``.

        Extra positional/keyword arguments are accepted (and ignored here) so
        that subclasses can forward their own options through ``super()``.
        """
        self._path_to_keyfile = path_to_keyfile
        with open(path_to_keyfile, 'rb') as key_fh:
            self._shared_key = key_fh.read().strip()
        # Active sessions, keyed by peer address.
        self.sessions = {}
        # Messages waiting to be returned by receive(), oldest first.
        self._messages = []

    @property
    def shared_key(self):
        """ The shared master key (bytes) currently in use. """
        return self._shared_key

    @shared_key.setter
    def shared_key(self, value):
        """ Persist *value* (bytes) to the key file, then use it in memory. """
        with open(self._path_to_keyfile, 'wb') as key_fh:
            key_fh.write(value)
        self._shared_key = value

    def receive(self):
        """ Block until a message is available and return the oldest one.

        Reads raw datagrams with ``read_data`` and dispatches them through
        ``handle_message`` until ``_messages`` is non-empty (presumably
        sessions append to it while handling traffic; confirm in sessions).
        Any exception from ``read_data`` (e.g. a socket timeout in a
        subclass) propagates to the caller.
        """
        while not self._messages:
            _logger.info('Listening...')
            data, sender = self.read_data()
            message = AuthenticatedMessage(sender, data)
            self.handle_message(message)
        return self._messages.pop(0)

    @contextmanager
    def connect(self, address):
        """ Open a client session to *address* for the duration of a ``with``.

        The session is terminated and the channel torn down when the block
        exits, even if the block raises. If the session cannot be set up,
        the channel is still torn down before the error propagates.
        """
        session = ClientSession(address, self)
        try:
            session.connect()
            try:
                # Session setup complete, let client use the session
                yield session
            finally:
                # Send terminate even if the caller's block failed, so the
                # peer does not keep a dangling session.
                _logger.info('Terminating session with %s...', session.id_a)
                session.terminate()
        finally:
            # Release any resources held by the channel
            self.tear_down()

    def tear_down(self):
        """ Release any resources used by the channel, if any. """
        pass

    def _send(self, data, address):
        """ Hand raw *data* for *address* to the transport's ``send_data``. """
        self.send_data(data, address)

    def send(self, data, address):
        """ Externally exposed interface for sending data.

        :raises KeyError: if there is no session with *address*.
        """
        session = self.sessions[address]
        session.send(data)

    def handle_message(self, message):
        """ Handle incoming message on the channel.

        Routes *message* to the session for its sender, creating a
        ``ServerSession`` for unknown senders. Afterwards:

        * an ``inactive`` session is removed from ``sessions``;
        * on ``rekey_confirmed`` all sessions are dropped and the session's
          new shared key is persisted via ``shared_key``.
        """
        session = self.sessions.get(message.sender)
        if session is None:
            session = ServerSession(message.sender, self)
            self.sessions[message.sender] = session
        session.handle(message.msg)
        if session.state == ServerState.inactive:
            _logger.info('Terminating session with %s', message.sender)
            # pop() rather than del: tolerate the entry already being gone.
            self.sessions.pop(message.sender, None)
        elif session.state == ServerState.rekey_confirmed:
            _logger.info('Rekey confirmed, new master key in place, invalidating all existing sessions..')
            self.sessions = {}
            self.shared_key = session.shared_key
            # Logged only after the key has actually been written.
            _logger.info('Session invalidated, shared key updated')

 

 

class UDPAuthChannel(AuthChannel):
    """ Authenticated channel carried over UDP/IPv4 datagrams. """

    #: The maximum size of packets sent and received on this channel
    mtu = 4096

    def __init__(self, *args, **kwargs):
        """ Create the channel and its default UDP socket.

        Accepts an optional ``timeout`` keyword (seconds, default ``2.0``)
        that is applied to this socket; all arguments are also forwarded to
        ``AuthChannel.__init__``.
        """
        super(UDPAuthChannel, self).__init__(*args, **kwargs)
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        timeout = kwargs.get('timeout', 2.0)
        self.sock.settimeout(timeout)

    def listen(self, address):
        """ Bind a fresh UDP socket to *address* (a ``(host, port)`` pair).

        The previously held socket is closed first so it is not leaked.
        :raises: whatever ``socket.bind`` raises, after logging it.
        """
        # A new socket is created (as before) instead of binding the one
        # from __init__, so it does not inherit that socket's timeout.
        # NOTE(review): presumably intentional so a server blocks in
        # read_data indefinitely -- confirm.
        self.sock.close()
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self.sock.bind(address)
        except Exception:
            _logger.exception('Exception occured while binding to socket, address %s.', address)
            raise
        _logger.info('Bound to %s:%s', *address)

    def send_data(self, data, address):
        """ Transmit *data* as a single datagram to *address*. """
        # ascii_bin(data) would otherwise be evaluated for every datagram
        # even when debug logging is disabled.
        if _logger.isEnabledFor(DEBUG):
            _logger.debug('Sending %s to %s', ascii_bin(data), address)
        self.sock.sendto(data, address)

    def read_data(self):
        """ Receive one datagram of at most ``mtu`` bytes.

        :returns: ``(data, sender)`` as returned by ``socket.recvfrom``.
        """
        data, sender = self.sock.recvfrom(self.mtu)
        if _logger.isEnabledFor(DEBUG):
            _logger.debug('Received data: %s from %s', ascii_bin(data), sender)
        return data, sender

    def tear_down(self):
        """ Close the channel's socket. """
        self.sock.close()

 

 

class DummyAuthChannel(AuthChannel): 

    """ Only return stuff locally, probably only useful for testing. """ 

 

    def __init__(self, *args, **kwargs): 

        super(DummyAuthChannel, self).__init__(*args, **kwargs) 

        self.sent_messages = [] 

        self.messages_to_receive = [] 

 

 

    def send_data(self, data, address): 

        _logger.debug('Sending data %s to %s' % (ascii_bin(data), address)) 

        self.sent_messages.append(AuthenticatedMessage(address, data)) 

 

 

    def read_data(self): 

        """ Return pre-filled replies, or None. """ 

        if self.messages_to_receive: 

            return self.messages_to_receive.pop(0)