Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Load Model: Use paths relative to workflow file #4534

Merged
merged 1 commit into from
Mar 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 39 additions & 84 deletions Orange/widgets/model/owloadmodel.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,29 @@
import os
import pickle

from AnyQt.QtWidgets import QSizePolicy, QStyle, QFileDialog
from AnyQt.QtCore import QTimer
from AnyQt.QtWidgets import (
QSizePolicy, QHBoxLayout, QComboBox, QStyle, QFileDialog
)

from Orange.base import Model
from Orange.widgets import widget, gui
from Orange.widgets.model import owsavemodel
from Orange.widgets.settings import Setting
from Orange.widgets.utils.filedialogs import RecentPathsWComboMixin
from Orange.widgets.utils import stdpaths
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import Msg, Output


class OWLoadModel(widget.OWWidget):
class OWLoadModel(widget.OWWidget, RecentPathsWComboMixin):
name = "Load Model"
description = "Load a model from an input file."
priority = 3050
replaces = ["Orange.widgets.classify.owloadclassifier.OWLoadClassifier"]
icon = "icons/LoadModel.svg"
keywords = ["file", "open"]
keywords = ["file", "open", "model"]

class Outputs:
model = Output("Model", Model)

#: List of recent filenames.
history = Setting([])
#: Current (last selected) filename or None.
filename = Setting(None)

class Error(widget.OWWidget.Error):
load_error = Msg("An error occured while reading '{}'")

Expand All @@ -41,96 +34,58 @@ class Error(widget.OWWidget.Error):

def __init__(self):
super().__init__()
self.selectedIndex = -1

box = gui.widgetBox(
self.controlArea, self.tr("File"), orientation=QHBoxLayout()
)
RecentPathsWComboMixin.__init__(self)
self.loaded_file = ""

self.filesCB = gui.comboBox(
box, self, "selectedIndex", callback=self._on_recent)
self.filesCB.setMinimumContentsLength(20)
self.filesCB.setSizeAdjustPolicy(
QComboBox.AdjustToMinimumContentsLength)
vbox = gui.vBox(self.controlArea, "File", addSpace=True)
box = gui.hBox(vbox)
self.file_combo.setMinimumWidth(300)
box.layout().addWidget(self.file_combo)
self.file_combo.activated[int].connect(self.select_file)

self.loadbutton = gui.button(box, self, "...", callback=self.browse)
self.loadbutton.setIcon(
self.style().standardIcon(QStyle.SP_DirOpenIcon))
self.loadbutton.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed)
button = gui.button(box, self, '...', callback=self.browse_file)
button.setIcon(self.style().standardIcon(QStyle.SP_DirOpenIcon))
button.setSizePolicy(
QSizePolicy.Maximum, QSizePolicy.Fixed)

self.reloadbutton = gui.button(
button = gui.button(
box, self, "Reload", callback=self.reload, default=True)
self.reloadbutton.setIcon(
self.style().standardIcon(QStyle.SP_BrowserReload))
self.reloadbutton.setSizePolicy(QSizePolicy.Maximum,
QSizePolicy.Fixed)

# filter valid existing filenames
self.history = list(filter(os.path.isfile, self.history))[:20]
for filename in self.history:
self.filesCB.addItem(os.path.basename(filename), userData=filename)

# restore the current selection if the filename is
# in the history list
if self.filename in self.history:
self.selectedIndex = self.history.index(self.filename)
else:
self.selectedIndex = -1
self.filename = None
self.reloadbutton.setEnabled(False)
button.setIcon(self.style().standardIcon(QStyle.SP_BrowserReload))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)

if self.filename:
QTimer.singleShot(0, lambda: self.load(self.filename))

def browse(self):
"""Select a filename using an open file dialog."""
if self.filename is None:
startdir = stdpaths.Documents
else:
startdir = os.path.dirname(self.filename)
self.set_file_list()
QTimer.singleShot(0, self.open_file)

def browse_file(self):
start_file = self.last_path() or stdpaths.Documents
filename, _ = QFileDialog.getOpenFileName(
self, self.tr("Open"), directory=startdir, filter=self.FILTER)
self, 'Open Distance File', start_file, self.FILTER)
if not filename:
return
self.add_path(filename)
self.open_file()

if filename:
self.load(filename)
def select_file(self, n):
super().select_file(n)
self.open_file()

def reload(self):
"""Reload the current file."""
self.load(self.filename)
self.open_file()

def load(self, filename):
"""Load the object from filename and send it to output."""
def open_file(self):
self.clear_messages()
fn = self.last_path()
if not fn:
return
try:
with open(filename, "rb") as f:
with open(fn, "rb") as f:
model = pickle.load(f)
except (pickle.UnpicklingError, OSError, EOFError):
self.Error.load_error(os.path.split(filename)[-1])
self.Error.load_error(os.path.split(fn)[-1])
self.Outputs.model.send(None)
else:
self.Error.load_error.clear()
self._remember(filename)
self.Outputs.model.send(model)

def _remember(self, filename):
"""
Remember `filename` was accessed.
"""
if filename in self.history:
index = self.history.index(filename)
del self.history[index]
self.filesCB.removeItem(index)

self.history.insert(0, filename)

self.filesCB.insertItem(0, os.path.basename(filename),
userData=filename)
self.selectedIndex = 0
self.filename = filename
self.reloadbutton.setEnabled(self.selectedIndex != -1)

def _on_recent(self):
self.load(self.history[self.selectedIndex])


if __name__ == "__main__": # pragma: no cover
WidgetPreview(OWLoadModel).run()
144 changes: 125 additions & 19 deletions Orange/widgets/model/tests/test_owloadmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,139 @@
# pylint: disable=missing-docstring
import os
import pickle
from tempfile import mkstemp
from tempfile import NamedTemporaryFile
import unittest
from unittest.mock import Mock, patch

from Orange.classification.majority import ConstantModel
import numpy as np

from orangewidget.utils.filedialogs import RecentPath
from Orange.data import Table
from Orange.classification.naive_bayes import NaiveBayesLearner
from Orange.widgets.model.owloadmodel import OWLoadModel
from Orange.widgets.tests.base import WidgetTest


class TestOWLoadModel(WidgetTest):
# Attribute used to store event data so it does not get garbage
# collected before event is processed.
event_data = None

def setUp(self):
self.widget = self.create_widget(OWLoadModel)
self.widget = self.create_widget(OWLoadModel) # type: OWLoadModel
data = Table("iris")
self.model = NaiveBayesLearner()(data)
with NamedTemporaryFile(suffix=".pkcls", delete=False) as f:
self.filename = f.name
pickle.dump(self.model, f)

def tearDown(self):
os.remove(self.filename)

def test_browse_file_opens_file(self):
w = self.widget
with patch("AnyQt.QtWidgets.QFileDialog.getOpenFileName",
Mock(return_value=(self.filename, "*.pkcls"))):
w.browse_file()
model = self.get_output(w.Outputs.model)
np.testing.assert_equal(
model.log_cont_prob, self.model.log_cont_prob)

with patch("AnyQt.QtWidgets.QFileDialog.getOpenFileName",
Mock(return_value=("", "*.pkcls"))):
w.browse_file()
# Keep the same model on output
model2 = self.get_output(w.Outputs.model)
self.assertIs(model2, model)

with patch("AnyQt.QtWidgets.QFileDialog.getOpenFileName",
Mock(return_value=(self.filename, "*.pkcls"))):
w.reload()
model2 = self.get_output(w.Outputs.model)
self.assertIsNot(model2, model)

@patch("pickle.load")
def test_select_file(self, load):
w = self.widget
with NamedTemporaryFile(suffix=".pkcls") as f2, \
NamedTemporaryFile(suffix=".pkcls", delete=False) as f3:
w.add_path(self.filename)
w.add_path(f2.name)
w.add_path(f3.name)
w.open_file()
args = load.call_args[0][0]
self.assertEqual(args.name, f3.name.replace("\\", "/"))
w.select_file(2)
args = load.call_args[0][0]
self.assertEqual(args.name, self.filename.replace("\\", "/"))

def test_show_error(self):
self.widget.load("no-such-file.pckls")
self.assertTrue(self.widget.Error.load_error.is_shown())
def test_load_error(self):
w = self.widget
with patch("AnyQt.QtWidgets.QFileDialog.getOpenFileName",
Mock(return_value=(self.filename, "*.pkcls"))):
with patch("pickle.load", side_effect=pickle.UnpicklingError):
w.browse_file()
self.assertTrue(w.Error.load_error.is_shown())
self.assertIsNone(self.get_output(w.Outputs.model))

clsf = ConstantModel([1, 1, 1])
fd, fname = mkstemp(suffix='.pkcls')
os.close(fd)
w.reload()
self.assertFalse(w.Error.load_error.is_shown())
model = self.get_output(w.Outputs.model)
self.assertIsNotNone(model)

with patch.object(w, "last_path", Mock(return_value="")), \
patch("pickle.load") as load:
w.reload()
load.assert_not_called()
self.assertFalse(w.Error.load_error.is_shown())
self.assertIs(self.get_output(w.Outputs.model), model)

with patch("pickle.load", side_effect=pickle.UnpicklingError):
w.reload()
self.assertTrue(w.Error.load_error.is_shown())
self.assertIsNone(self.get_output(w.Outputs.model))

with patch("AnyQt.QtWidgets.QFileDialog.getOpenFileName",
Mock(return_value=("foo", "*.pkcls"))):
w.browse_file()
self.assertTrue(w.Error.load_error.is_shown())
self.assertIsNone(self.get_output(w.Outputs.model))

def test_no_last_path(self):
self.widget = \
self.create_widget(OWLoadModel,
stored_settings={"recent_paths": []})
# Doesn't crash and contains a single item, (none).
self.assertEqual(self.widget.file_combo.count(), 1)

@patch("Orange.widgets.widget.OWWidget.workflowEnv",
Mock(return_value={"basedir": os.getcwd()}))
@patch("pickle.load")
def test_open_moved_workflow(self, load):
"""
Test opening workflow that has been moved to another location
(i.e. sent by email), considering data file is stored in the same
directory as the workflow.
"""
temp_file = NamedTemporaryFile(dir=os.getcwd(), delete=False)
file_name = temp_file.name
temp_file.close()
base_name = os.path.basename(file_name)
try:
with open(fname, 'wb') as f:
pickle.dump(clsf, f)
self.widget.load(fname)
self.assertFalse(self.widget.Error.load_error.is_shown())

with open(fname, "w") as f:
f.write("X")
self.widget.load(fname)
self.assertTrue(self.widget.Error.load_error.is_shown())
recent_path = RecentPath(
os.path.join("temp/models", base_name), "",
os.path.join("models", base_name)
)
stored_settings = {"recent_paths": [recent_path]}
w = self.create_widget(OWLoadModel,
stored_settings=stored_settings)
w.open_file()
self.assertEqual(w.file_combo.count(), 1)
args = load.call_args[0][0]
self.assertEqual(args.name, file_name.replace("\\", "/"))
finally:
os.remove(fname)
os.remove(file_name)


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion doc/widgets.json
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@
"background": "#FAC1D9",
"keywords": [
"file",
"open"
"open",
"model"
]
}
]
Expand Down