Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

from . import encoding 

from .exceptions import DuplicateDomainException, NotReadyException, NoSuchDomainException 

 

import decorator 

import getpass 

import math 

import os 

import requests 

import scrypt 

import sqlalchemy as sa 

import sys 

import time 

import traceback 

from sqlalchemy.ext.declarative import declarative_base 

from sqlalchemy.orm import sessionmaker 

from logging import getLogger 

 

 

# Module logger; the name is fixed to 'pwm.core' rather than derived from __name__.
_logger = getLogger('pwm.core')
# Declarative base that the ORM models in this module (e.g. Domain) inherit from.
Base = declarative_base()

 

 

class Domain(Base):
    """ Domain objects hold all the data for a given domain name.

    Domain names can in theory be anything, from user selected aliases to actual domain names like
    facebook.com or twitter.com, however the latter is probably recommended as it opens up the
    possibility to automatically extract the relevant objects if the user visits the site, such as
    in a browser extension or similar.

    :param name: The identifier for this domain.
    :param alphabet: The alphabet to restrict key contents to. Default: 'full'
    :param key_length: The length of the computed key. Can be useful if the site imposes restrictions
        on password length. Default: 16
    :param salt: Optional salt to use. When omitted, a fresh random 32-byte salt is generated.
    :param username: Optional username to associate with this domain.
    """
    DEFAULT_KEY_LENGTH = 16
    DEFAULT_ALPHABET = 'full'

    __tablename__ = 'domain'
    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String(30), unique=True)
    salt = sa.Column(sa.LargeBinary(128))
    charset = sa.Column(sa.String(128))
    key_length = sa.Column(sa.Integer())
    username = sa.Column(sa.String(40))

    def __init__(self, alphabet=DEFAULT_ALPHABET, key_length=DEFAULT_KEY_LENGTH, **kwargs):
        # Resolve the alphabet name to its concrete character set before the remaining column
        # values are handed to the declarative constructor.
        if alphabet:
            self.charset = encoding.lookup_alphabet(alphabet)
        super(Domain, self).__init__(key_length=key_length, **kwargs)
        # Only generate a salt when the caller did not supply one (e.g. a salt fetched remotely).
        if 'salt' not in kwargs:
            self.new_salt()

    @property
    def entropy(self):
        """ Entropy of a generated key in bits: key_length * log2(number of unique characters).

        Computed as a product instead of ``-log2(1 / unique_chars**key_length)``. The latter
        raises OverflowError once ``unique_chars**key_length`` no longer fits in a float (long
        keys), and ZeroDivisionError for an empty charset.
        """
        unique_chars = len(set(self.charset))
        if unique_chars < 2:
            # Zero or one distinct character: every key is the same, so there is no entropy.
            return 0.0
        return self.key_length * math.log(unique_chars, 2)

    def new_salt(self):
        """ Replace the salt with 32 fresh bytes from ``os.urandom``. """
        self.salt = os.urandom(32)

    def derive_key(self, master_password):
        """ Computes the key from the salt and the master password.

        The string ``'<master_password>:<domain name>'`` is UTF-8 encoded and hashed with scrypt
        using this domain's salt. The digest is then passed to the charset encoder together with
        ``key_length``.

        :param master_password: The user's master password.
        :returns: The key produced by the encoder.
        """
        encoder = encoding.Encoder(self.charset)

        secret = ('%s:%s' % (master_password, self.name)).encode('utf8')

        # time.clock() was removed in Python 3.8. Prefer perf_counter() and only fall back to
        # clock() on interpreters (Python 2) that lack it.
        timer = getattr(time, 'perf_counter', None) or time.clock
        start_time = timer()
        # we fix the scrypt parameters in case the defaults change
        digest = scrypt.hash(secret, self.salt, N=1<<14, r=8, p=1)

        key = encoder.encode(digest, self.key_length)
        derivation_time_in_s = timer() - start_time

        _logger.debug('Key derivation took %.2fms', derivation_time_in_s*1000)
        return key

    def get_key(self):
        """ Fetches the key for the domain. Prompts the user for password.

        Thin wrapper around :func:`Domain.derive_key <pwm.core.Domain.derive_key>`.
        """
        master_password = getpass.getpass('Enter your master password: ')
        return self.derive_key(master_password)

    def __repr__(self): # pragma: no cover
        return 'Domain(name=%s, salt=%s, charset=%s, key_length=%s)' \
                % (self.name, self.salt, self.charset, self.key_length)

 

 

def _urify_db(path_or_uri): 

    """ Get a SQLAlchemy compatible database URI. 

 

    If a path is given, append sqlite:/// in the front, if protocol details are already provided, 

    return unchanged. 

    """ 

    if '://' in path_or_uri: 

        return path_or_uri 

    else: 

        return 'sqlite:///%s' % path_or_uri 

 

 

 

@decorator.decorator
def _uses_db(func, self, *args, **kwargs):
    """ Run a method inside a database session, committing on success and rolling back on error.

    A session is created via ``self._init_db_session()`` when ``self.session`` is not yet set.
    After the wrapped call the session is committed, or rolled back if anything was raised (the
    exception is logged at debug level and re-raised). The session is always closed afterwards.
    Only usable on methods of objects providing ``session`` and ``_init_db_session``.
    """
    if not self.session:
        _logger.debug('Creating new db session')
        self._init_db_session()
    try:
        result = func(self, *args, **kwargs)
        self.session.commit()
    except BaseException:
        # Equivalent to a bare except: undo the transaction for *any* failure, then propagate.
        self.session.rollback()
        _logger.debug(traceback.format_exc())
        raise
    else:
        return result
    finally:
        _logger.debug('Closing db session')
        self.session.close()

 

 

class PWM(object):
    """ This is the main object for interfacing with a pwm database.

    :param database_uri: The path to the database to use, or a SQLAlchemy-compatible connection
        URI, like `postgresql://user:pw@host/db`. With an `http://` or `https://` URI,
        :func:`PWM.get_domain <pwm.core.PWM.get_domain>` fetches domains from a REST API instead.
        If not given or None, :func:`PWM.bootstrap <pwm.core.PWM.bootstrap>` must be called
        before doing any operations that operate on the database.
    """

    # URI schemes for which get_domain talks to a REST API instead of a SQLAlchemy database.
    _REST_PROTOCOLS = ('https', 'http')

    def __init__(self, database_uri=None):
        # Created on demand by _init_db_session() the first time a @_uses_db method runs.
        self.session = None
        self.database_uri = _urify_db(database_uri) if database_uri else None

    def bootstrap(self, path_or_uri):
        """ Initialize a database.

        Creates all tables registered on ``Base.metadata`` and points this object at the new
        database for subsequent operations.

        :param path_or_uri: The path to the database to initialize, or a SQLAlchemy-compatible
            connection URI.
        """
        _logger.debug("Bootstrapping new database: %s", path_or_uri)
        self.database_uri = _urify_db(path_or_uri)
        db = sa.create_engine(self.database_uri)
        Base.metadata.create_all(db)

    @_uses_db
    def search(self, query):
        """ Search the database for the given query. Will find partial matches (case-insensitive). """
        results = self.session.query(Domain).filter(Domain.name.ilike('%%%s%%' % query)).all()
        return results

    def get_domain(self, domain_name):
        """ Get the :class:`Domain <pwm.core.Domain>` object from a name.

        :param domain_name: The domain name to fetch the object for.
        :returns: The :class:`Domain <pwm.core.Domain>` object with this domain_name.
        :raises NoSuchDomainException: If the local database has no domain with this name.
        :raises NotReadyException: If no database has been configured.
        """
        # Not decorated with @_uses_db itself: the REST path must not try to build a SQLAlchemy
        # engine from an http(s) URI. Only the local-database path opens a session.
        if self._is_rest_backend():
            return self._get_domain_from_rest_api(domain_name)
        return self._get_domain_from_db_or_raise(domain_name)

    def _is_rest_backend(self):
        """ Whether the configured database URI uses one of the REST API schemes. """
        if not self.database_uri:
            # Leave the "no database" case to the session setup, which raises NotReadyException.
            return False
        protocol = self.database_uri.split(':', 1)[0]
        return protocol in self._REST_PROTOCOLS

    @_uses_db
    def _get_domain_from_db_or_raise(self, domain_name):
        """ Look a domain up in the local database inside its own session, raising if absent. """
        domain = self._get_domain_from_db(domain_name)
        if domain is None:
            raise NoSuchDomainException
        return domain

    def _get_domain_from_rest_api(self, domain):
        """ Fetch the salt for `domain` from the remote `/get` endpoint and wrap it in a Domain. """
        # `config` is never set in __init__; presumably it is attached by the caller (TODO
        # confirm). Fall back to an empty config instead of failing with AttributeError.
        config = getattr(self, 'config', None) or {}
        request_args = {
            'params': {'domain': domain},
        }
        verify = True
        server_certificate = config.get('server_certificate')
        if server_certificate:
            # NOTE(review): joining dirname(x) with x doubles the directory for relative paths
            # ('certs/ca.pem' -> 'certs/certs/ca.pem'). Kept as-is until the intended base
            # directory is confirmed.
            verify = os.path.join(os.path.dirname(server_certificate), server_certificate)
            _logger.debug('Pinning server with certificate at %s', verify)

        # Test for SNI support on python 2
        if sys.version_info < (3, 0, 0):
            try:
                import urllib3.contrib.pyopenssl
                urllib3.contrib.pyopenssl.inject_into_urllib3()
            except ImportError:
                _logger.warning("Running on python 2 without SNI support, can't verify server certificates.")
                verify = False
        request_args['verify'] = verify

        if config.get('auth'):
            request_args['cert'] = config['auth']
        # get_domain only reaches this method when database_uri is an http(s) URI, so it is a
        # sensible default base URL when the config does not name one.
        base_url = config.get('database') or self.database_uri
        response = requests.get(base_url + '/get', **request_args)
        # Surface HTTP errors directly instead of as a confusing JSON/KeyError further down.
        response.raise_for_status()
        # NOTE(review): only the salt is taken from the response; alphabet and key length fall
        # back to the Domain defaults.
        domain = Domain(name=domain, salt=response.json()['salt'])
        return domain

    def _get_domain_from_db(self, domain_name):
        """ Return the Domain with this exact name from the current session, or None. """
        domain = self.session.query(Domain).filter(Domain.name == domain_name).first()
        return domain

    @_uses_db
    def modify_domain(self, domain_name, new_salt=False, username=None):
        """ Modify an existing domain.

        :param domain_name: The name of the domain to modify.
        :param new_salt: Whether to generate a new salt for the domain.
        :param username: If given, change domain username to this value.
        :returns: The modified :class:`Domain <pwm.core.Domain>` object.
        :raises NoSuchDomainException: If no domain with this name exists.
        """
        domain = self._get_domain_from_db(domain_name)
        if domain is None:
            raise NoSuchDomainException
        if new_salt:
            _logger.info("Generating new salt..")
            domain.new_salt()
        if username is not None:
            domain.username = username
        return domain

    def create_domain(self, domain_name, username=None, alphabet=Domain.DEFAULT_ALPHABET,
            length=Domain.DEFAULT_KEY_LENGTH):
        """ Create a new domain entry in the database.

        :param domain_name: The name of the new domain.
        :param username: The username to associate with this domain.
        :param alphabet: A character set restriction to impose on keys generated for this domain.
        :param length: The length of the generated key, in case of restrictions on the site.
        :returns: The new :class:`Domain <pwm.core.Domain>` object.
        :raises NotReadyException: If no database has been configured.
        :raises DuplicateDomainException: If inserting the domain failed for any other reason.
        """
        # Wrap the actual implementation to do some error handling
        try:
            return self._create_domain(domain_name, username, alphabet, length)
        except NotReadyException:
            # An unconfigured database is not a duplicate; let the caller see the real cause.
            raise
        except Exception as ex:
            _logger.warning("Inserting new domain failed: %s", ex)
            raise DuplicateDomainException

    @_uses_db
    def _create_domain(self, domain_name, username, alphabet, length):
        domain = Domain(name=domain_name, username=username, key_length=length,
            alphabet=alphabet)
        self.session.add(domain)
        return domain

    def _init_db_session(self):
        """ Create a session bound to a new engine for ``database_uri``. """
        if not self.database_uri:
            raise NotReadyException()
        db = sa.create_engine(self.database_uri)
        # expire_on_commit=False keeps returned Domain attributes readable after the commit
        # and close performed by _uses_db.
        DBSession = sessionmaker(bind=db, expire_on_commit=False)
        self.session = DBSession()