-
Notifications
You must be signed in to change notification settings - Fork 0
/
source_coder.py
115 lines (108 loc) · 4.08 KB
/
source_coder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/python
# Functions for compressing and recovering (English) text with arithmetic coding. Both
# functions take a generative model that predicts a probability distribution over the
# possible next characters.
# The list of allowed characters: printable ASCII chr(32)..chr(126) plus chr(127), which is
# reserved as the end-of-message marker. The model must return a map of probabilities whose
# domain is this list. It covers all letters, digits and common symbols, but note that
# newlines, tabs and non-ASCII characters are not included.
character_list = [chr(i) for i in range(32, 128)]
# The resolution of the encoding process: probabilities generated by the model are
# incorporated to an accuracy of at least 1/resolution. Each character's share of the
# coding interval is rounded down to a whole number of units, so a character whose
# probability is below roughly 1/resolution gets an empty interval and cannot be encoded.
# The end-of-message marker chr(127) receives whatever mass remains after chr(32)..chr(126),
# so the model's probabilities for those characters must sum to less than 1.
resolution = 2 ** 30
def compress(text, model):
    """Arithmetic-encode ``text`` into a list of bits using ``model``.

    An end-of-message marker, chr(127), is appended to ``text`` and encoded
    too, so ``decompress`` knows where to stop. Each character is assigned the
    part of the current coding interval that follows the probability mass of
    every character with a smaller code point; the EOF marker takes whatever
    is left after chr(126).

    Args:
        text: String to encode. Every character must be in ``character_list``
            and must not be chr(127), which is reserved as the EOF marker.
        model: Object with ``probabilities()``, returning a mapping from each
            character in ``character_list`` to its probability, and
            ``next_char(c)``, which is called after each character (including
            the final EOF marker) is encoded so the model can update.

    Returns:
        A list of booleans, one per output bit (True represents a 1).

    Raises:
        ValueError: If ``text`` contains an unsupported character, or if the
            model leaves a character with an empty interval at the current
            resolution (its probability rounds down to zero, or the
            probabilities of chr(32)..chr(126) leave no room for the EOF
            marker), in which case it cannot be encoded.
    """
    eof = chr(127)
    allowed = set(character_list)
    allowed.discard(eof)
    # Validate up front so the model is not advanced for a message that
    # cannot be encoded. An embedded chr(127) would silently truncate the
    # message on decompression, and characters outside the list have no
    # well-defined interval.
    for ch in text:
        if ch not in allowed:
            raise ValueError(
                "cannot encode character %r: not in character_list "
                "or reserved as the EOF marker" % ch
            )

    compressed_message = []
    bottom = 0
    top = 2 * resolution
    max_val = top
    for a in text + eof:
        probs = model.probabilities()
        # Widen the interval (and the scale with it) until it spans at least
        # `resolution` units, bounding the int() truncation error below.
        current_res = top - bottom
        while current_res < resolution:
            top *= 2
            bottom *= 2
            max_val *= 2
            current_res *= 2
        interval_end = bottom + current_res
        # Skip past the shares of all characters with a smaller code point.
        code = ord(a)
        for x in character_list:
            if ord(x) < code:
                bottom += int(probs[x] * current_res)
        if a == eof:
            top = interval_end
        else:
            top = bottom + int(probs[a] * current_res)
        if top <= bottom:
            # An empty interval would make the emission loop below spin forever.
            raise ValueError(
                "character %r has an empty interval: the model assigns it "
                "no probability at this resolution" % a
            )
        # Emit each leading bit that is now fixed because the interval lies
        # entirely in one half of [0, max_val), then zoom into that half.
        # The max_val > 1 guard stops the loop once no half can be split off
        # (otherwise a width-1 interval could drive max_val to 0 and loop).
        while max_val > 1:
            half = max_val // 2
            if bottom >= half:
                compressed_message.append(True)  # represents a 1.
                bottom -= half
                top -= half
            elif top < half:
                compressed_message.append(False)  # represents a 0.
            else:
                break
            max_val = half
        # Tell the model the current character so it can update the
        # distribution for the next one.
        model.next_char(a)

    # Flush: keep choosing the half that overlaps the final interval
    # [bottom, top) the most, until the range selected by the emitted bits,
    # [0, max_val) in the current scale, lies entirely inside it.
    while bottom > 0 or top < max_val:
        half = max_val // 2
        if top - half > half - bottom:
            top -= half
            bottom = max(bottom - half, 0)
            compressed_message.append(True)
        else:
            top = min(top, half)
            compressed_message.append(False)
        max_val = half
    return compressed_message
def decompress(message, model):
    """Decode a bit list produced by ``compress`` back into text.

    ``model`` must start in the same state as the model given to ``compress``
    so that it yields the same sequence of distributions; ``next_char`` is
    called after every decoded character, including the EOF marker chr(127).

    Args:
        message: Sequence of booleans (True represents a 1 bit), as returned
            by ``compress``.
        model: Object with ``probabilities()`` and ``next_char(c)``, as for
            ``compress``.

    Returns:
        The decoded text, without the EOF marker. If the bits run out before
        an EOF marker is decoded, or they select a point outside the current
        coding interval, the message is invalid and the text recovered so far
        is returned instead. The result is always a ``str``.
    """
    eof = chr(127)
    text = []
    bottom = 0
    top = 2 * resolution
    max_val = top
    message_position = 0  # index of the first bit not yet consumed
    while not text or text[-1] != eof:
        probs = model.probabilities()
        # Same rescaling as compress(), so both sides compute identical intervals.
        current_res = top - bottom
        while current_res < resolution:
            top *= 2
            bottom *= 2
            max_val *= 2
            current_res *= 2
        # character_list[i] owns [character_intervals[i], character_intervals[i + 1]);
        # the EOF marker (last entry) takes whatever remains of the interval.
        character_intervals = [bottom]
        boundary = bottom
        for x in character_list:
            if ord(x) < 127:
                boundary += int(probs[x] * current_res)
                character_intervals.append(boundary)
        character_intervals.append(bottom + current_res)

        # Look ahead through the bits, narrowing the range they describe,
        # until that range fits inside a single character's interval.
        temp_pos = message_position
        temp_bottom = 0
        temp_top = max_val
        while True:
            if temp_pos >= len(message):
                # The bits do not reach an EOF marker: return what was
                # recovered so far.
                return "".join(text)
            if message[temp_pos]:
                temp_bottom = (temp_bottom + temp_top) // 2
            else:
                temp_top = (temp_bottom + temp_top) // 2
            temp_pos += 1
            starts_below = [b for b in character_intervals if b < temp_top]
            if not starts_below:
                # Range lies below the coding interval: invalid message.
                return "".join(text)
            if starts_below[-1] <= temp_bottom:
                next_char_index = len(starts_below) - 1
                break
        if next_char_index >= len(character_intervals) - 1:
            # Range lies above the coding interval: invalid message.
            return "".join(text)
        text.append(character_list[next_char_index])

        # Consume the bits compress() emitted for this character, applying the
        # same renormalisation (including the max_val > 1 guard) so both
        # sides stay in step.
        bottom = character_intervals[next_char_index]
        top = character_intervals[next_char_index + 1]
        while max_val > 1 and (bottom >= max_val // 2 or top < max_val // 2):
            half = max_val // 2
            if bottom >= half:
                bottom -= half
                top -= half
            max_val = half
            message_position += 1
        model.next_char(text[-1])
    return "".join(text[:-1])