Skip to content

Commit

Permalink
Merge pull request #867 from PrimozGodec/fix-proxies
Browse files Browse the repository at this point in the history
Fix proxy addresses for embedders and NLTK
  • Loading branch information
lanzagar authored Jun 20, 2022
2 parents c05b4fb + b2b3e4d commit eecdb90
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 9 deletions.
37 changes: 37 additions & 0 deletions orangecontrib/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,42 @@
# Set where NLTK data is downloaded
import os

# Temporary solution - remove when Orange 3.33 is released.
# This block must run before nltk_data_dir is imported below.
from typing import Optional, Dict
from Orange.misc.utils import embedder_utils


def _get_proxies() -> Optional[Dict[str, str]]:
    """
    Return dict with proxy addresses if they exist.

    Addresses are read from the ``http_proxy`` and ``https_proxy``
    environment variables. An address given without a scheme gets
    ``http://`` prepended. Variables that are unset, empty or contain only
    whitespace are ignored.

    Returns
    -------
    proxy_dict
        Dictionary with format {proxy type: proxy address} or None if
        they are not set.
    """
    def add_scheme(url: Optional[str]) -> Optional[str]:
        # Treat an empty/blank variable like an unset one; prefixing it
        # would otherwise produce the bogus, truthy address "http://".
        url = (url or "").strip()
        if not url:
            return None
        if "://" not in url:
            # if no scheme default to http - as other libraries do (e.g. requests)
            return f"http://{url}"
        return url

    proxy_dict = {}
    for key, variable in (("http://", "http_proxy"), ("https://", "https_proxy")):
        address = add_scheme(os.environ.get(variable))
        if address:
            proxy_dict[key] = address
    return proxy_dict or None


# Replace Orange's get_proxies with the local _get_proxies defined above.
embedder_utils.get_proxies = _get_proxies
# remove to here


# Imported only now, after the temporary patch above has been applied.
from orangecontrib.text.misc import nltk_data_dir
# Expose the directory returned by nltk_data_dir() through NLTK_DATA.
os.environ['NLTK_DATA'] = nltk_data_dir()

Expand Down
45 changes: 36 additions & 9 deletions orangecontrib/text/misc/nltk_data_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from functools import wraps
from threading import Thread
from urllib.parse import urlparse, ParseResult

import nltk
from Orange.misc.environ import data_dir_base
Expand Down Expand Up @@ -38,18 +39,44 @@ def nltk_data_dir():
is_done_loading = False


def _download_nltk_data():
global is_done_loading
# Port appended to a proxy address that does not specify one, keyed by the
# URL scheme. For any scheme missing here the user must provide the port.
DEFAULT_PORTS = {
    "http": "80",
    "https": "443",
    "socks4": "1080",
    "socks": "1080",
    "quic": "443",
}

# set proxy if exist

def _get_proxy_address():
    """
    Return the proxy address that should be given to NLTK, or None.

    NLTK does not use the proxy address from the https_proxy environment
    variable on its own, so the address is resolved here. Only the https
    proxy is considered. When the address carries no explicit port, the
    default port of its scheme (see DEFAULT_PORTS) is appended; an address
    with a scheme missing from DEFAULT_PORTS is returned without a port.

    This function only computes the address - it does not call
    ``nltk.set_proxy`` itself.

    Returns
    -------
    proxy_address : str or None
        Proxy URL, or None when no https proxy is configured.

    Raises
    ------
    ValueError
        Raised by ``urllib.parse`` when the address contains an invalid port.
    """
    proxies = get_proxies() or {}
    # nltk uses https to download data
    address = proxies.get("https://")
    if not address:
        return None

    proxy = urlparse(address)
    # Log the host only: the address may embed credentials (user:pass@host).
    log.debug("Using proxy for NLTK: %s", proxy.hostname)

    default_port = DEFAULT_PORTS.get(proxy.scheme)
    if proxy.port is not None or default_port is None or not proxy.hostname:
        # Explicit port, unknown scheme, or no host to attach a port to.
        return proxy.geturl()

    # Extend the original netloc rather than rebuilding it from hostname, so
    # that credentials and IPv6 brackets are preserved.
    netloc = "{}:{}".format(proxy.netloc.rstrip(":"), default_port)
    return proxy._replace(netloc=netloc).geturl()


def _download_nltk_data():
    """
    Download the NLTK_DATA resources into the directory from nltk_data_dir().

    A proxy is set on NLTK first when _get_proxy_address() returns one. After
    the download call the module-level flag is_done_loading is set to True
    and stdout is flushed.
    """
    global is_done_loading

    address = _get_proxy_address()
    if address:
        nltk.set_proxy(address)

    target_dir = nltk_data_dir()
    nltk.download(NLTK_DATA, download_dir=target_dir, quiet=True)

    is_done_loading = True
    sys.stdout.flush()
Expand Down
10 changes: 10 additions & 0 deletions orangecontrib/text/tests/test_documentembedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,16 @@ def test_invalid_parameters(self):
with self.assertRaises(ValueError):
self.embedder = DocumentEmbedder(aggregator='average')

def test_remove_temporary_proxy_solution(self):
    """
    Reminder to drop the temporary proxy workaround.

    When it starts to fail:
    - remove this test
    - remove temporary implementation of _get_proxies() function in
      text.__init__
    - set minimum version of Orange on 3.33
    """
    import Orange

    # Compare numerically: as plain strings versions are ordered
    # lexicographically (e.g. "3.100.0" would sort before "3.34.0").
    # Non-numeric components (dev/rc suffixes) are ignored.
    release = tuple(
        int(part)
        for part in Orange.__version__.split(".")[:3]
        if part.isdigit()
    )
    self.assertLess(release, (3, 34, 0))


if __name__ == "__main__":
unittest.main()
36 changes: 36 additions & 0 deletions orangecontrib/text/tests/test_nltk_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import unittest
from orangecontrib.text.misc.nltk_data_download import _get_proxy_address


class TestNLTKDownload(unittest.TestCase):
    """Checks for the proxy address resolved for the NLTK download."""

    def setUp(self) -> None:
        # Remember the caller's https_proxy and start every test without it.
        self.previous_https = os.environ.pop("https_proxy", None)

    def tearDown(self) -> None:
        # Drop whatever the test set, then put the remembered value back.
        os.environ.pop("https_proxy", None)
        if self.previous_https is not None:
            os.environ["https_proxy"] = self.previous_https

    def test_get_proxy_address(self):
        # no https_proxy set -> no proxy address
        self.assertIsNone(_get_proxy_address())

        # (value of https_proxy, address expected from _get_proxy_address)
        cases = (
            ("https://test.com", "https://test.com:443"),
            ("https://test.com:12", "https://test.com:12"),
            ("https://test.com/test", "https://test.com:443/test"),
            ("https://test.com/test?a=2", "https://test.com:443/test?a=2"),
            ("test.com/test?a=2", "http://test.com:80/test?a=2"),
        )
        for setting, expected in cases:
            os.environ["https_proxy"] = setting
            self.assertEqual(expected, _get_proxy_address())


if __name__ == "__main__":
unittest.main()

0 comments on commit eecdb90

Please sign in to comment.