Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option --custom to perform test with custom servers #784

Open
wants to merge 4 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ pip-log.txt

#Mr Developer
.mr.developer.cfg

#Environment
speedtest-env
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,8 @@ def find_version(*file_paths):
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
]
)
161 changes: 133 additions & 28 deletions speedtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,30 @@
# License for the specific language governing permissions and limitations
# under the License.

import os
import re
import csv
import sys
import math
import datetime
import errno
import math
import os
import platform
import re
import signal
import socket
import timeit
import datetime
import platform
import sys
import threading
import timeit
import xml.parsers.expat

import requests

try:
import gzip
GZIP_BASE = gzip.GzipFile
except ImportError:
gzip = None
GZIP_BASE = object

__version__ = '2.1.3'
__version__ = '2.1.4b1'


class FakeShutdownEvent(object):
Expand All @@ -49,13 +51,16 @@ def isSet():
"Dummy method to always return false"""
return False

is_set = isSet


# Some global variables we use
DEBUG = False
_GLOBAL_DEFAULT_TIMEOUT = object()
PY25PLUS = sys.version_info[:2] >= (2, 5)
PY26PLUS = sys.version_info[:2] >= (2, 6)
PY32PLUS = sys.version_info[:2] >= (3, 2)
PY310PLUS = sys.version_info[:2] >= (3, 10)

# Begin import game to handle Python 2 and Python 3
try:
Expand Down Expand Up @@ -266,17 +271,6 @@ def write(data):
write(arg)
write(end)

if PY32PLUS:
etree_iter = ET.Element.iter
elif PY25PLUS:
etree_iter = ET_Element.getiterator

if PY26PLUS:
thread_is_alive = threading.Thread.is_alive
else:
thread_is_alive = threading.Thread.isAlive


# Exception "constants" to support Python 2 through Python 3
try:
import ssl
Expand All @@ -293,6 +287,23 @@ def write(data):
ssl = None
HTTP_ERRORS = (HTTPError, URLError, socket.error, BadStatusLine)

if PY32PLUS:
etree_iter = ET.Element.iter
elif PY25PLUS:
etree_iter = ET_Element.getiterator

if PY26PLUS:
thread_is_alive = threading.Thread.is_alive
else:
thread_is_alive = threading.Thread.isAlive


def event_is_set(event):
try:
return event.is_set()
except AttributeError:
return event.isSet()


class SpeedtestException(Exception):
"""Base exception for this module"""
Expand All @@ -311,7 +322,7 @@ class SpeedtestConfigError(SpeedtestException):


class SpeedtestServersError(SpeedtestException):
"""Servers XML is invalid"""
"""Servers XML or JSON is invalid"""


class ConfigRetrievalError(SpeedtestHTTPError):
Expand Down Expand Up @@ -769,7 +780,7 @@ def print_dots(shutdown_event):
status
"""
def inner(current, total, start=False, end=False):
if shutdown_event.isSet():
if event_is_set(shutdown_event):
return

sys.stdout.write('.')
Expand Down Expand Up @@ -808,7 +819,7 @@ def run(self):
try:
if (timeit.default_timer() - self.starttime) <= self.timeout:
f = self._opener(self.request)
while (not self._shutdown_event.isSet() and
while (not event_is_set(self._shutdown_event) and
(timeit.default_timer() - self.starttime) <=
self.timeout):
self.result.append(len(f.read(10240)))
Expand Down Expand Up @@ -864,7 +875,7 @@ def data(self):

def read(self, n=10240):
if ((timeit.default_timer() - self.start) <= self.timeout and
not self._shutdown_event.isSet()):
not event_is_set(self._shutdown_event)):
chunk = self.data.read(n)
self.total.append(len(chunk))
return chunk
Expand Down Expand Up @@ -902,7 +913,7 @@ def run(self):
request = self.request
try:
if ((timeit.default_timer() - self.starttime) <= self.timeout and
not self._shutdown_event.isSet()):
not event_is_set(self._shutdown_event)):
try:
f = self._opener(request)
except TypeError:
Expand Down Expand Up @@ -1228,7 +1239,38 @@ def get_config(self):

return self.config

def get_servers(self, servers=None, exclude=None):
def json_to_xml(self,data=None, server_id_list=None):
"""Converts text data representing a link with json or json text to XML"""
if data:
try:
r = requests.get(data)
except requests.exceptions.MissingSchema:
raise SpeedtestServersError("Invalid --custom link")
if r.status_code == 200:
message = '<?xml version="1.0" encoding="UTF-8"?>\n<settings>\n<servers>'
try:
json_data = json.loads(r.text)
if server_id_list and len(server_id_list)>=1:
for server_json in json_data:
if int(server_json["id"]) in server_id_list:
json_data = server_json
try:
message += f'<server url="{json_data["url"]}" lat="{json_data["lat"]}" lon="{json_data["lon"]}" name="{json_data["name"]}" country="{json_data["country"]}" cc="{json_data["cc"]}" sponsor="{json_data["sponsor"]}" id="{json_data["id"]}" host="{json_data["host"]}" />'
except (KeyError,SyntaxError) as e:
pass
else:
json_data = json_data[0]
try:
message += f'<server url="{json_data["url"]}" lat="{json_data["lat"]}" lon="{json_data["lon"]}" name="{json_data["name"]}" country="{json_data["country"]}" cc="{json_data["cc"]}" sponsor="{json_data["sponsor"]}" id="{json_data["id"]}" host="{json_data["host"]}" />'
except (KeyError,SyntaxError) as e:
pass
except json.decoder.JSONDecodeError:
raise SpeedtestServersError("Invalid json data provided by the link")
message += "\n</servers>\n</settings>\n"
return message.replace("&","").replace("%","").encode()


def get_servers(self, servers=None, exclude=None, custom_server=None):
"""Retrieve a the list of speedtest.net servers, optionally filtered
to servers matching those specified in the ``servers`` argument
"""
Expand Down Expand Up @@ -1261,6 +1303,68 @@ def get_servers(self, servers=None, exclude=None):
headers['Accept-Encoding'] = 'gzip'

errors = []
if custom_server:
if custom_server and servers:
serversxml = "".encode().join([self.json_to_xml(custom_server,servers)])
else:
serversxml = "".encode().join([self.json_to_xml(custom_server)])
try:
try:
try:
root = ET.fromstring(serversxml)
except ET.ParseError:
e = get_exception()
raise SpeedtestServersError(
'Malformed speedtest.net server list: %s' % e
)
elements = etree_iter(root, 'server')
except AttributeError:
try:
root = DOM.parseString(serversxml)
except ExpatError:
e = get_exception()
raise SpeedtestServersError(
'Malformed speedtest.net server list: %s' % e
)
elements = root.getElementsByTagName('server')
except (SyntaxError, xml.parsers.expat.ExpatError):
raise ServersRetrievalError()

for server in elements:
try:
attrib = server.attrib
except AttributeError:
attrib = dict(list(server.attributes.items()))

if servers and int(attrib.get('id')) not in servers:
continue

if (int(attrib.get('id')) in self.config['ignore_servers']
or int(attrib.get('id')) in exclude):
continue

try:
d = distance(self.lat_lon,
(float(attrib.get('lat')),
float(attrib.get('lon'))))
except Exception:
continue

attrib['d'] = d

try:
self.servers[d].append(attrib)
except KeyError:
self.servers[d] = [attrib]

except ServersRetrievalError:
pass

if (servers or exclude) and not self.servers:
raise NoMatchedServers()

return self.servers

for url in urls:
try:
request = build_request(
Expand Down Expand Up @@ -1290,9 +1394,7 @@ def get_servers(self, servers=None, exclude=None):

if int(uh.code) != 200:
raise ServersRetrievalError()

serversxml = ''.encode().join(serversxml_list)

printer('Servers XML:\n%s' % serversxml, debug=True)

try:
Expand Down Expand Up @@ -1775,6 +1877,7 @@ def parse_args():
help='Show the version number and exit')
parser.add_argument('--debug', action='store_true',
help=ARG_SUPPRESS, default=ARG_SUPPRESS)
parser.add_argument('--custom',help="Test with a custom server using its link")

options = parser.parse_args()
if isinstance(options, tuple):
Expand Down Expand Up @@ -1907,7 +2010,9 @@ def shell():
if not args.mini:
printer('Retrieving speedtest.net server list...', quiet)
try:
speedtest.get_servers(servers=args.server, exclude=args.exclude)
speedtest.get_servers(servers=args.server,
exclude=args.exclude,
custom_server=args.custom)
except NoMatchedServers:
raise SpeedtestCLIError(
'No matched servers: %s' %
Expand Down