
# -*- coding: utf-8 -*-
#
# This file is part of INSPIRE.
# Copyright (C) 2014-2017 CERN.
#
# INSPIRE is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# INSPIRE is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with INSPIRE. If not, see <http://www.gnu.org/licenses/>.
#
# In applying this license, CERN does not waive the privileges and immunities
# granted to it by virtue of its status as an Intergovernmental Organization
# or submit itself to any jurisdiction.

"""Workflows utils."""

from __future__ import absolute_import, division, print_function

import json
import os
import traceback
from contextlib import closing, contextmanager
from functools import wraps
from six import text_type
from six.moves.urllib.parse import unquote
import backoff
import lxml.etree as ET
import requests
from flask import current_app, url_for
from timeout_decorator import timeout, TimeoutError
from fs.opener import fsopen
from invenio_db import db

from inspire_schemas.utils import \
    get_validation_errors as _get_validation_errors
from inspire_utils.logging import getStackTraceLogger

from inspirehep.utils.url import retrieve_uri
from inspirehep.modules.workflows.models import (
    WorkflowsAudit,
    WorkflowsRecordSources,
)


LOGGER = getStackTraceLogger(__name__)


@backoff.on_exception(backoff.expo,
                      requests.packages.urllib3.exceptions.ConnectionError,
                      base=4,
                      max_tries=5)
def json_api_request(url, data, headers=None):
    """Make JSON API request and return JSON response."""
    final_headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    if headers:
        final_headers.update(headers)
    current_app.logger.debug("POST {0} with \n{1}".format(
        url, json.dumps(data, indent=4)
    ))
    try:
        response = requests.post(
            url=url,
            headers=final_headers,
            data=json.dumps(data),
        )
    except requests.exceptions.RequestException as err:
        current_app.logger.exception(err)
        raise

    if response.status_code == 200:
        return response.json()

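# Usage sketch (illustrative, not part of the original module): posting a
# payload from within an application context. The URL and payload below are
# made up for the example; note the implicit None return on non-200 responses.
#
#     result = json_api_request(
#         url='https://inspirehep.example/api/robotupload',
#         data={'arxiv_id': '1234.5678'},
#         headers={'Authorization': 'Bearer <token>'},
#     )
#     # ``result`` is the decoded JSON body on HTTP 200, otherwise None.
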
def log_workflows_action(action, relevance_prediction,
                         object_id, user_id,
                         source, user_action=""):
    """Log the action taken by user compared to a prediction."""
    if relevance_prediction:
        score = relevance_prediction.get("max_score")  # returns 0.222113
        decision = relevance_prediction.get("decision")  # returns "Rejected"

        # Map actions to align with the prediction format
        action_map = {
            'accept': 'Non-CORE',
            'accept_core': 'CORE',
            'reject': 'Rejected',
        }

        logging_info = {
            'object_id': object_id,
            'user_id': user_id,
            'score': score,
            'user_action': action_map.get(user_action, ""),
            'decision': decision,
            'source': source,
            'action': action,
        }
        audit = WorkflowsAudit(**logging_info)
        audit.save()

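# Example call (illustrative; the prediction dict mirrors the keys read
# above, the values are made up):
#
#     log_workflows_action(
#         action='record_decision',
#         relevance_prediction={'max_score': 0.222113, 'decision': 'Rejected'},
#         object_id=42,
#         user_id=1,
#         source='holdingpen',
#         user_action='accept_core',  # logged as 'CORE' via action_map
#     )
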
def with_debug_logging(func):
    """Generate a debug log with info on what's going to run.

    It tries its best to use the logging facilities of the object passed
    or the application context before falling back to the python logging
    facility.
    """
    @wraps(func)
    def _decorator(*args, **kwargs):
        def _get_obj(args, kwargs):
            if args:
                obj = args[0]
            else:
                obj = kwargs.get('obj', kwargs.get('record'))
            return obj

        def _get_logfn(args, kwargs):
            obj = _get_obj(args, kwargs)
            if hasattr(obj, 'log') and hasattr(obj.log, 'debug'):
                logfn = obj.log.debug
            elif hasattr(current_app, 'logger'):
                logfn = current_app.logger.debug
            else:
                logfn = LOGGER.debug
            return logfn

        def _try_to_log(logfn, *args, **kwargs):
            try:
                logfn(*args, **kwargs)
            except Exception:
                LOGGER.debug(
                    'Error while trying to log with %s:\n%s',
                    logfn,
                    traceback.format_exc()
                )

        logfn = _get_logfn(args, kwargs)
        _try_to_log(logfn, 'Starting %s', func)
        res = func(*args, **kwargs)
        _try_to_log(
            logfn,
            "Finished %s with (single quoted) result '%s'",
            func,
            res,
        )
        return res

    return _decorator

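# Usage sketch: decorating a workflow step so its start and result are
# debug-logged on the object's own logger when available (the step below is
# hypothetical):
#
#     @with_debug_logging
#     def mark_approved(obj, eng):
#         obj.extra_data['approved'] = True
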
def do_not_repeat(step_id):
    """Decorator used to skip workflow steps when a workflow is re-run.

    Will store the result of running the workflow step in
    ``source_data.persistent_data`` after running the first time, and skip
    the step on the following runs, also applying previously recorded
    'changes' to ``extra_data``.

    The decorated function has to conform to the following signature::

        def decorated_step(obj: WorkflowObject, eng: WorkflowEngine) -> Dict[str, Any]:
            ...

    where ``obj`` and ``eng`` are the usual arguments following the protocol
    of all workflow steps. The returned value of the decorated step will be
    used as a patch to be applied on the workflow object's source data
    (which 'replays' changes made by the workflow step).

    Args:
        step_id (str): name of the workflow step, to be used as key in
            persistent_data.

    Returns:
        callable: the decorator.
    """
    def decorator(func):
        @wraps(func)
        def _do_not_repeat(obj, eng):
            source_data = obj.extra_data['source_data']
            is_task_repeated = step_id in source_data.setdefault(
                'persistent_data', {}
            )

            if is_task_repeated:
                extra_data_update = source_data['persistent_data'][step_id]
                obj.extra_data.update(extra_data_update)
                obj.save()
                return

            return_value = func(obj, eng)
            if not isinstance(return_value, dict):
                raise TypeError(
                    "Functions decorated by 'do_not_repeat' must return a "
                    "dictionary compliant with extra_data info"
                )
            source_data['persistent_data'][step_id] = return_value
            obj.save()
            return return_value
        return _do_not_repeat
    return decorator

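# Usage sketch (hypothetical step): the returned dict is both applied to
# ``extra_data`` and recorded in ``source_data.persistent_data`` under the
# given ``step_id``, so re-runs skip the body and replay the patch instead:
#
#     @do_not_repeat('populate_submission_document')
#     def populate_submission_document(obj, eng):
#         ...  # expensive work, done only on the first run
#         return {'submission_downloaded': True}
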
def ignore_timeout_error(return_value=None):
    """Ignore the TimeoutError, returning return_value when it happens.

    Quick fix for ``refextract`` and ``plotextract`` tasks only.
    It shouldn't be used for others!
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except TimeoutError:
                LOGGER.error(
                    'Timeout error while extracting raised from: %s.',
                    func.__name__
                )
                return return_value
        return wrapper
    return decorator

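# Usage sketch (hypothetical task): a ``TimeoutError`` raised inside the
# decorated task is logged and swallowed, and ``[]`` is returned instead:
#
#     @ignore_timeout_error(return_value=[])
#     def extract_references(obj, eng):
#         ...  # may raise TimeoutError
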
def timeout_with_config(config_key):
    """Decorator to set a configurable timeout on a function.

    Args:
        config_key (str): config key with an integer value representing the
            time in seconds after which the decorated function will abort,
            raising a ``TimeoutError``. If the key is not present in the
            config, a ``KeyError`` is raised.

    Note:
        This function is needed because it's impossible to pass a value read
        from the config as an argument to a decorator, as it gets evaluated
        before the application context is set up.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            timeout_time = current_app.config[config_key]
            return timeout(timeout_time)(func)(*args, **kwargs)
        return wrapper
    return decorator

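# Usage sketch (the config key below is hypothetical): the timeout value is
# looked up on every call, so it can differ between deployments without
# re-importing the module:
#
#     @timeout_with_config('REFEXTRACT_TIMEOUT')
#     def run_refextract(obj, eng):
#         ...  # aborts with TimeoutError after REFEXTRACT_TIMEOUT seconds
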
@contextmanager
@with_debug_logging
def get_document_in_workflow(obj):
    """Context manager giving the path to the document attached to a
    workflow object.

    Args:
        obj: workflow object.

    Returns:
        Optional[str]: the path to a local copy of the document. If no
        documents are present, it returns None. If several documents are
        present, it prioritizes the fulltext. If several documents with the
        same priority are present, it takes the first one and logs an error.
    """
    documents = obj.data.get('documents', [])
    fulltexts = [
        document for document in documents if document.get('fulltext')
    ]
    documents = fulltexts or documents

    if not documents:
        obj.log.info('No document available')
        yield None
        return
    elif len(documents) > 1:
        obj.log.error('More than one document in workflow, first one used')

    key = documents[0]['key']
    obj.log.info('Using document with key "%s"', key)
    with retrieve_uri(obj.files[key].file.uri) as local_file:
        yield local_file

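# Usage sketch: as a context manager, so the local copy made by
# ``retrieve_uri`` is cleaned up on exit; always check for None:
#
#     with get_document_in_workflow(obj) as path:
#         if path is not None:
#             ...  # read the file at ``path``
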
@with_debug_logging
def copy_file_to_workflow(workflow, name, url):
    url = unquote(url)
    stream = fsopen(url, mode='rb')
    workflow.files[name] = stream
    return workflow.files[name]

@backoff.on_exception(backoff.expo,
                      requests.packages.urllib3.exceptions.ProtocolError,
                      max_tries=5)
def download_file_to_workflow(workflow, name, url):
    """Download a file to a specified workflow.

    The ``workflow.files`` property is actually a method, which returns a
    ``WorkflowFilesIterator``. This class inherits a custom ``__setitem__``
    method from its parent, ``FilesIterator``, which ends up calling ``save``
    on an ``invenio_files_rest.storage.pyfs.PyFSFileStorage`` instance
    through ``ObjectVersion`` and ``FileObject``. This method consumes the
    stream passed to it and saves in its place a ``FileObject`` with the
    details of the downloaded file.

    Consuming the stream might raise a ``ProtocolError`` because the server
    might terminate the connection before sending any data. In this case we
    retry 5 times with exponential backoff before giving up.
    """
    with closing(requests.get(url=url, stream=True)) as req:
        if req.status_code == 200:
            req.raw.decode_content = True
            workflow.files[name] = req.raw
            return workflow.files[name]

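# Usage sketch (illustrative URL): streams the response body straight into
# the workflow's file storage and returns the resulting ``FileObject``, or
# None on a non-200 response:
#
#     fileobj = download_file_to_workflow(
#         workflow=obj,
#         name='fulltext.pdf',
#         url='https://arxiv.example/pdf/1234.5678',
#     )
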
def convert(xml, xslt_filename):
    """Convert XML using given XSLT stylesheet."""
    if not os.path.isabs(xslt_filename):
        prefix_dir = os.path.dirname(os.path.realpath(__file__))
        xslt_filename = os.path.join(prefix_dir, "stylesheets", xslt_filename)

    dom = ET.fromstring(xml)
    xslt = ET.parse(xslt_filename)
    transform = ET.XSLT(xslt)
    newdom = transform(dom)
    return ET.tostring(newdom, pretty_print=False)

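# Usage sketch (the stylesheet name is hypothetical): relative names are
# resolved against this module's ``stylesheets/`` directory:
#
#     marcxml = convert(oai_xml, 'oai2marcxml.xsl')
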
def read_wf_record_source(record_uuid, source):
    """Retrieve a record from the ``WorkflowsRecordSources`` table.

    Args:
        record_uuid (uuid): the uuid of the record.
        source (string): the acquisition source value of the record.

    Returns:
        (dict): the given record, if any, or None.
    """
    if not source:
        return
    source = get_source_for_root(source)
    entry = WorkflowsRecordSources.query.filter_by(
        record_uuid=str(record_uuid),
        source=source.lower(),
    ).one_or_none()
    return entry

def read_all_wf_record_sources(record_uuid):
    """Retrieve all ``WorkflowsRecordSources`` entries for a given record id.

    Args:
        record_uuid (uuid): the uuid of the record.

    Returns:
        (list): the ``WorkflowsRecordSources`` entries related to
        ``record_uuid``.
    """
    entries = list(
        WorkflowsRecordSources.query.filter_by(record_uuid=str(record_uuid))
    )
    return entries

def insert_wf_record_source(json, record_uuid, source):
    """Store a record in the ``WorkflowsRecordSources`` table in the db.

    Args:
        json (dict): the record's content to store.
        record_uuid (uuid): the record's uuid.
        source (string): the source of the record.
    """
    if not source:
        return
    source = get_source_for_root(source)

    record_source = read_wf_record_source(
        record_uuid=record_uuid,
        source=source,
    )
    if record_source is None:
        record_source = WorkflowsRecordSources(
            source=source.lower(),
            json=json,
            record_uuid=record_uuid,
        )
        db.session.add(record_source)
    else:
        record_source.json = json
    db.session.commit()

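# Usage sketch: upsert the latest root for a record coming from arXiv, then
# read it back (``record_json`` and ``uuid`` below are hypothetical):
#
#     insert_wf_record_source(json=record_json, record_uuid=uuid,
#                             source='arxiv')
#     root = read_wf_record_source(record_uuid=uuid, source='arxiv')
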
def get_source_for_root(source):
    """Source for the root workflow object.

    Args:
        source (str): the record source.

    Returns:
        (str): the source for the root workflow object.

    Note:
        For the time being, any workflow with ``acquisition_source.source``
        different from ``arxiv`` and ``submitter`` will be stored as
        ``publisher``.
    """
    return source if source in ['arxiv', 'submitter'] else 'publisher'

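# Examples of the mapping described above. Note that the membership test is
# exact, so a differently-cased 'arXiv' would also map to 'publisher':
#
#     get_source_for_root('arxiv')      # -> 'arxiv'
#     get_source_for_root('submitter')  # -> 'submitter'
#     get_source_for_root('Elsevier')   # -> 'publisher'
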
def get_resolve_validation_callback_url():
    """Return the callback url for resolving the validation errors.

    Note:
        It's using the ``inspire_workflows.callback_resolve_validation``
        route.
    """
    return url_for(
        'inspire_workflows_callbacks.callback_resolve_validation',
        _external=True
    )

def get_resolve_merge_conflicts_callback_url():
    """Return the callback url for resolving the merge conflicts.

    Note:
        It's using the ``inspire_workflows.callback_resolve_merge_conflicts``
        route.
    """
    return url_for(
        'inspire_workflows_callbacks.callback_resolve_merge_conflicts',
        _external=True
    )

def get_resolve_edit_article_callback_url():
    """Resolve the edit_article workflow, letting it continue.

    Note:
        It's using the ``inspire_workflows.callback_resolve_edit_article``
        route.
    """
    return url_for(
        'inspire_workflows_callbacks.callback_resolve_edit_article',
        _external=True
    )

def get_validation_errors(data, schema):
    """Create a list of ``validation_errors`` dictionaries.

    Args:
        data (dict): the object to validate.
        schema (str): the name of the schema.

    Returns:
        list: ``validation_errors`` formatted dicts.
    """
    errors = _get_validation_errors(data, schema=schema)
    error_messages = [
        {
            # list() keeps the result JSON-serializable on Python 3,
            # where map() returns a lazy iterator instead of a list.
            'path': list(map(text_type, error.absolute_path)),
            'message': text_type(error.message),
        }
        for error in errors
    ]
    return error_messages

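# Usage sketch (hypothetical invalid record): each entry pairs the path of
# the offending field with the validation message produced by the schema:
#
#     errors = get_validation_errors({'_collections': 1}, schema='hep')
#     # e.g. [{'path': ['_collections'],
#     #        'message': "1 is not of type 'array'"}, ...]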