New Research: Supply Chain Attack on Axios Pulls Malicious Dependency from npm. Details →
Socket
Book a DemoSign in
Socket

python-taint

Package Overview
Dependencies
Maintainers
3
Versions
7
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

python-taint - pypi Package Compare versions

Comparing version
0.34
to
0.37
pyt/analysis/__init__.py
+19
"""Global lookup table for constraints.
Uses cfg node as key and operates on bitvectors in the form of ints."""

# Maps each CFG node to its current constraint bitvector (an int).
constraint_table = {}
def initialize_constraint_table(cfg_list):
    """Collects all given cfg nodes and initializes the table with value 0."""
    for cfg in cfg_list:
        for cfg_node in cfg.nodes:
            constraint_table[cfg_node] = 0
def constraint_join(cfg_nodes):
    """Looks up all cfg_nodes and joins the bitvectors by using logical or."""
    joined = 0
    for cfg_node in cfg_nodes:
        joined |= constraint_table[cfg_node]
    return joined
from collections import defaultdict
from .constraint_table import constraint_table
from ..core.node_types import AssignmentNode
def get_constraint_nodes(
    node,
    lattice
):
    """Yield every lattice element reaching `node`, excluding `node` itself."""
    for element in lattice.get_elements(constraint_table[node]):
        if element is not node:
            yield element
def build_def_use_chain(
    cfg_nodes,
    lattice
):
    """Map each defining assignment node to the later assignments that use it.

    For every assignment node, every variable on its right hand side is
    looked up among the definitions reaching it; each reaching definition
    whose LHS contains the variable gets this node appended to its use list.
    """
    def_use = defaultdict(list)
    for use_node in cfg_nodes:
        # Only assignments both define and use variables.
        if not isinstance(use_node, AssignmentNode):
            continue
        for variable in use_node.right_hand_side_variables:
            for definition_node in get_constraint_nodes(use_node, lattice):
                if variable in definition_node.left_hand_side:
                    def_use[definition_node].append(use_node)
    return def_use
"""This module implements the fixed point algorithm."""
from .constraint_table import constraint_table
from .reaching_definitions_taint import ReachingDefinitionsTaintAnalysis
class FixedPointAnalysis():
    """Run the fix point analysis."""

    def __init__(self, cfg):
        """Fixed point analysis.

        Analysis must be a dataflow analysis containing a 'fixpointmethod'
        method that analyses one CFG.
        """
        self.analysis = ReachingDefinitionsTaintAnalysis(cfg)
        self.cfg = cfg

    def fixpoint_runner(self):
        """Work list algorithm that runs the fixpoint algorithm."""
        # Bug fix: the original aliased self.cfg.nodes, so appends to the work
        # list during the first iteration silently grew the CFG's own node
        # list.  Work on a copy instead.
        q = self.cfg.nodes[:]
        while q != []:
            x_i = constraint_table[q[0]]  # x_i = q[0].old_constraint
            self.analysis.fixpointmethod(q[0])  # y = F_i(x_1, ..., x_n);
            y = constraint_table[q[0]]  # y = q[0].new_constraint
            if y != x_i:
                # Re-queue every node that depends on q[0].
                for node in self.analysis.dep(q[0]):  # for (v in dep(v_i))
                    q.append(node)  # q.append(v):
                constraint_table[q[0]] = y  # q[0].old_constraint = q[0].new_constraint # x_i = y
            q = q[1:]  # q = q.tail()  # The list minus the head
def analyse(cfg_list):
    """Analyse a list of control flow graphs with a given analysis type."""
    for cfg in cfg_list:
        FixedPointAnalysis(cfg).fixpoint_runner()
from .constraint_table import constraint_table
from ..core.node_types import AssignmentNode
def get_lattice_elements(cfg_nodes):
    """Returns all assignment nodes as they are the only lattice elements
    in the reaching definitions analysis.
    """
    return (node for node in cfg_nodes
            if isinstance(node, AssignmentNode))
class Lattice:
    """Maps lattice elements (assignment nodes) to bits of an int bitvector.

    Element i is given bit ``1 << i``.  ``bv2el`` is built by prepending,
    so index 0 holds the element with the *highest* bit — this matches the
    MSB-first binary string that ``get_elements`` iterates over.
    """

    def __init__(self, cfg_nodes):
        self.el2bv = dict()  # Element to bitvector dictionary
        self.bv2el = list()  # Bitvector to element list
        for i, e in enumerate(get_lattice_elements(cfg_nodes)):
            # Give each element a unique shift of 1
            self.el2bv[e] = 0b1 << i
            # Prepend so bv2el runs from the highest bit down to bit 0.
            self.bv2el.insert(0, e)

    def get_elements(self, number):
        """Return the list of elements whose bits are set in ``number``."""
        if number == 0:
            return []
        elements = list()
        # Turn number into a binary string of length len(self.bv2el)
        binary_string = format(number,
                               '0' + str(len(self.bv2el)) + 'b')
        # binary_string is MSB-first, matching bv2el's reversed order.
        for i, bit in enumerate(binary_string):
            if bit == '1':
                elements.append(self.bv2el[i])
        return elements

    def in_constraint(self, node1, node2):
        """Checks if node1 is in node2's constraints
        For instance, if node1 = 010 and node2 = 110:
        010 & 110 = 010 -> has the element."""
        constraint = constraint_table[node2]
        if constraint == 0b0:
            return False
        try:
            value = self.el2bv[node1]
        except KeyError:
            # node1 is not a lattice element (i.e. not an assignment node).
            return False
        return constraint & value != 0
from .constraint_table import (
constraint_join,
constraint_table
)
from ..core.node_types import AssignmentNode
from .lattice import Lattice
class ReachingDefinitionsTaintAnalysis():
    """Reaching-definitions dataflow variant used for taint tracking."""

    def __init__(self, cfg):
        self.cfg = cfg
        # Lattice built over the CFG's assignment nodes.
        self.lattice = Lattice(cfg.nodes)

    def fixpointmethod(self, cfg_node):
        """The most important part of PyT, where we perform
        the variant of reaching definitions to find where sources reach.
        """
        JOIN = self.join(cfg_node)
        # Assignment check
        if isinstance(cfg_node, AssignmentNode):
            arrow_result = JOIN
            # Reassignment check
            if cfg_node.left_hand_side not in cfg_node.right_hand_side_variables:
                # Get previous assignments of cfg_node.left_hand_side and remove them from JOIN
                arrow_result = self.arrow(JOIN, cfg_node.left_hand_side)
            # Add this assignment's own bit to the reaching set.
            arrow_result = arrow_result | self.lattice.el2bv[cfg_node]
            constraint_table[cfg_node] = arrow_result
        # Default case
        else:
            constraint_table[cfg_node] = JOIN

    def join(self, cfg_node):
        """Joins all constraints of the ingoing nodes and returns them.
        This represents the JOIN auxiliary definition from Schwartzbach."""
        return constraint_join(cfg_node.ingoing)

    def arrow(self, JOIN, _id):
        """Removes all previous assignments from JOIN that have the same left hand side.
        This represents the arrow id definition from Schwartzbach."""
        r = JOIN
        for node in self.lattice.get_elements(JOIN):
            if node.left_hand_side == _id:
                # XOR clears the bit; it is known to be set since node came
                # from get_elements(JOIN).
                r = r ^ self.lattice.el2bv[node]
        return r

    def dep(self, q_1):
        """Represents the dep mapping from Schwartzbach."""
        for node in q_1.outgoing:
            yield node
from .make_cfg import make_cfg
__all__ = ['make_cfg']
"""This module contains alias helper functions for the expr_visitor module."""
def as_alias_handler(alias_list):
    """Returns a list of all the names that will be called."""
    # Prefer the 'as' name when present, otherwise the plain name.
    return [alias.asname if alias.asname else alias.name
            for alias in alias_list]
def handle_aliases_in_calls(name, import_alias_mapping):
    """Returns either None or the handled alias.
    Used in add_module.
    """
    for key, val in import_alias_mapping.items():
        # Match the exact alias, or a dotted prefix of the name:
        # e.g. 'Foo' matches both 'Foo' and 'Foo.Bar'.
        if name == key or name.startswith(key + '.'):
            # Replace key with val in name,
            # e.g. StarbucksVisitor.Tea -> Eataly.Tea because
            # "from .nested_folder import StarbucksVisitor as Eataly"
            return name.replace(key, val)
    return None
def handle_aliases_in_init_files(name, import_alias_mapping):
    """Returns either None or the handled alias.
    Used in add_module.
    """
    for key, val in import_alias_mapping.items():
        # Mirror image of handle_aliases_in_calls: match against the alias
        # value, e.g. 'Eataly' matches 'Eataly' and 'Eataly.Tea'.
        if name == val or name.startswith(val + '.'):
            # Replace val with key in name,
            # e.g. Eataly.Tea -> StarbucksVisitor.Tea because
            # "from .nested_folder import StarbucksVisitor as Eataly"
            return name.replace(val, key)
    return None
def handle_fdid_aliases(module_or_package_name, import_alias_mapping):
    """Returns either None or the handled alias.
    Used in add_module.
    fdid means from directory import directory.
    """
    # Return the first alias key whose value matches, or None.
    return next(
        (key for key, val in import_alias_mapping.items()
         if val == module_or_package_name),
        None
    )
def not_as_alias_handler(names_list):
    """Returns a list of names ignoring any aliases."""
    return [alias.name for alias in names_list]
def retrieve_import_alias_mapping(names_list):
    """Creates a dictionary mapping aliases to their respective name.
    import_alias_names is used in module_definitions.py and visit_Call"""
    return {alias.asname: alias.name
            for alias in names_list
            if alias.asname}
from collections import namedtuple
from ..core.node_types import ConnectToExitNode
# (LHS, RHS) pair recording one local variable saved around a function call.
SavedVariable = namedtuple('SavedVariable', 'LHS RHS')
# Call names treated as known built-ins/framework calls rather than
# blackboxes (see visit_Call); mostly Flask and SQLAlchemy entry points.
BUILTINS = (
    'get',
    'Flask',
    'run',
    'replace',
    'read',
    'set_cookie',
    'make_response',
    'SQLAlchemy',
    'Column',
    'execute',
    'sessionmaker',
    'Session',
    'filter',
    'call',
    'render_template',
    'redirect',
    'url_for',
    'flash',
    'jsonify'
)
def return_connection_handler(nodes, exit_node):
    """Connect all return statements to the Exit node."""
    for function_body_node in nodes:
        if not isinstance(function_body_node, ConnectToExitNode):
            continue
        # Avoid duplicate edges to the exit node.
        if exit_node not in function_body_node.outgoing:
            function_body_node.connect(exit_node)
import ast
from .alias_helper import handle_aliases_in_calls
from ..core.ast_helper import (
Arguments,
get_call_names_as_string
)
from ..core.module_definitions import ModuleDefinitions
from ..core.node_types import (
AssignmentCallNode,
AssignmentNode,
BBorBInode,
ConnectToExitNode,
EntryOrExitNode,
IgnoredNode,
Node,
RestoreNode,
ReturnNode,
YieldNode
)
from .expr_visitor_helper import (
BUILTINS,
return_connection_handler,
SavedVariable
)
from ..helper_visitors import (
LabelVisitor,
RHSVisitor
)
from .stmt_visitor import StmtVisitor
from .stmt_visitor_helper import CALL_IDENTIFIER
class ExprVisitor(StmtVisitor):
    """CFG builder for expression-level nodes (calls, yields, names).

    Extends StmtVisitor; together they build the control flow graph of a
    module, or of a single function body when `module_definitions` is given.
    """

    def __init__(
        self,
        node,
        project_modules,
        local_modules,
        filename,
        module_definitions=None,
        allow_local_directory_imports=True
    ):
        """Create an empty CFG."""
        super().__init__(allow_local_directory_imports=allow_local_directory_imports)
        self.project_modules = project_modules
        self.local_modules = local_modules if self._allow_local_modules else []
        self.filenames = [filename]         # stack; [-1] is the file being visited
        self.blackbox_assignments = set()   # assignments whose RHS is an unknown call
        self.nodes = list()                 # all CFG nodes, in creation order
        self.function_call_index = 0        # unique N per processed call
        self.undecided = False
        self.function_names = list()
        self.function_return_stack = list()
        self.module_definitions_stack = list()
        self.prev_nodes_to_avoid = list()
        self.last_control_flow_nodes = list()
        # Are we already in a module?
        if module_definitions:
            self.init_function_cfg(node, module_definitions)
        else:
            self.init_cfg(node)
    def init_cfg(self, node):
        """Build the CFG for a whole module: Entry -> statements -> Exit."""
        self.module_definitions_stack.append(ModuleDefinitions(filename=self.filenames[-1]))
        entry_node = self.append_node(EntryOrExitNode('Entry module'))
        module_statements = self.visit(node)
        if not module_statements:
            raise Exception('Empty module. It seems that your file is empty,' +
                            'there is nothing to analyse.')
        exit_node = self.append_node(EntryOrExitNode('Exit module'))
        if isinstance(module_statements, IgnoredNode):
            # Nothing analysable in the body; wire Entry straight to Exit.
            entry_node.connect(exit_node)
            return
        first_node = module_statements.first_statement
        # Nodes generated for calls (labels containing CALL_IDENTIFIER) are
        # already wired up by the call-handling machinery.
        if CALL_IDENTIFIER not in first_node.label:
            entry_node.connect(first_node)
        last_nodes = module_statements.last_statements
        exit_node.connect_predecessors(last_nodes)

    def init_function_cfg(self, node, module_definitions):
        """Build the CFG for a single function body inside an existing module."""
        self.module_definitions_stack.append(module_definitions)
        self.function_names.append(node.name)
        self.function_return_stack.append(node.name)
        entry_node = self.append_node(EntryOrExitNode('Entry function'))
        module_statements = self.stmt_star_handler(node.body)
        exit_node = self.append_node(EntryOrExitNode('Exit function'))
        if isinstance(module_statements, IgnoredNode):
            entry_node.connect(exit_node)
            return
        first_node = module_statements.first_statement
        if CALL_IDENTIFIER not in first_node.label:
            entry_node.connect(first_node)
        last_nodes = module_statements.last_statements
        exit_node.connect_predecessors(last_nodes)
    def visit_Yield(self, node):
        """Model `yield X` as `yld_<function name> += X`."""
        label = LabelVisitor()
        label.visit(node)
        if node.value is None:
            # Bare `yield` has no RHS variables.
            rhs_visitor_result = []
        else:
            rhs_visitor_result = RHSVisitor.result_for_node(node.value)
        # Yield is a bit like augmented assignment to a return value
        this_function_name = self.function_return_stack[-1]
        LHS = 'yld_' + this_function_name
        return self.append_node(YieldNode(
            LHS + ' += ' + label.result,
            LHS,
            node,
            rhs_visitor_result + [LHS],
            path=self.filenames[-1])
        )

    def visit_YieldFrom(self, node):
        # Treated exactly like a plain yield.
        return self.visit_Yield(node)

    def visit_Attribute(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Name(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_NameConstant(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Str(self, node):
        # Bare string statements (e.g. docstrings) add nothing to the CFG.
        return IgnoredNode()

    def visit_Subscript(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Tuple(self, node):
        return self.visit_miscelleaneous_node(
            node
        )
    def connect_if_allowed(
        self,
        previous_node,
        node_to_connect_to
    ):
        """Connect previous_node to node_to_connect_to, unless control flow
        dictates connecting from a pending control-flow test node instead,
        or the edge must be suppressed entirely."""
        # e.g.
        # while x != 10:
        #     if x > 0:
        #         print(x)
        #         break
        #     else:
        #         print('hest')
        # print('next')  # self.nodes[-1] is print('hest')
        #
        # So we connect to `while x != 10` instead
        if self.last_control_flow_nodes[-1]:
            self.last_control_flow_nodes[-1].connect(node_to_connect_to)
            # Consume the pending control-flow node so it is only used once.
            self.last_control_flow_nodes[-1] = None
            return
        # Except in this case:
        #
        # if not image_name:
        #     return 404
        # print('foo')  # We do not want to connect this line with `return 404`
        if previous_node is not self.prev_nodes_to_avoid[-1] and not isinstance(previous_node, ReturnNode):
            previous_node.connect(node_to_connect_to)
def save_local_scope(
self,
line_number,
saved_function_call_index
):
"""Save the local scope before entering a function call by saving all the LHS's of assignments so far.
Args:
line_number(int): Of the def of the function call about to be entered into.
saved_function_call_index(int): Unique number for each call.
Returns:
saved_variables(list[SavedVariable])
first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
"""
saved_variables = list()
saved_variables_so_far = set()
first_node = None
# Make e.g. save_N_LHS = assignment.LHS for each AssignmentNode
for assignment in [node for node in self.nodes
if (type(node) == AssignmentNode or
type(node) == AssignmentCallNode or
type(Node) == BBorBInode)]: # type() is used on purpose here
if assignment.left_hand_side in saved_variables_so_far:
continue
saved_variables_so_far.add(assignment.left_hand_side)
save_name = 'save_{}_{}'.format(saved_function_call_index, assignment.left_hand_side)
previous_node = self.nodes[-1]
saved_scope_node = RestoreNode(
save_name + ' = ' + assignment.left_hand_side,
save_name,
[assignment.left_hand_side],
line_number=line_number,
path=self.filenames[-1]
)
if not first_node:
first_node = saved_scope_node
self.nodes.append(saved_scope_node)
# Save LHS
saved_variables.append(SavedVariable(LHS=save_name,
RHS=assignment.left_hand_side))
self.connect_if_allowed(previous_node, saved_scope_node)
return (saved_variables, first_node)
    def save_def_args_in_temp(
        self,
        call_args,
        def_args,
        line_number,
        saved_function_call_index,
        first_node
    ):
        """Save the arguments of the definition being called. Visit the arguments if they're calls.

        Args:
            call_args(list[ast.Name]): Of the call being made.
            def_args(ast_helper.Arguments): Of the definition being called.
            line_number(int): Of the call being made.
            saved_function_call_index(int): Unique number for each call.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.

        Returns:
            args_mapping(dict): A mapping of call argument to definition argument.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
        """
        args_mapping = dict()
        last_return_value_of_nested_call = None
        # Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument
        for i, call_arg in enumerate(call_args):
            # If this results in an IndexError it is invalid Python
            def_arg_temp_name = 'temp_' + str(saved_function_call_index) + '_' + def_args[i]
            return_value_of_nested_call = None
            if isinstance(call_arg, ast.Call):
                # Nested call: visit it first, then alias its return value.
                return_value_of_nested_call = self.visit(call_arg)
                restore_node = RestoreNode(
                    def_arg_temp_name + ' = ' + return_value_of_nested_call.left_hand_side,
                    def_arg_temp_name,
                    [return_value_of_nested_call.left_hand_side],
                    line_number=line_number,
                    path=self.filenames[-1]
                )
                # Propagate blackbox status from the nested call.
                if return_value_of_nested_call in self.blackbox_assignments:
                    self.blackbox_assignments.add(restore_node)
            else:
                call_arg_label_visitor = LabelVisitor()
                call_arg_label_visitor.visit(call_arg)
                call_arg_rhs_visitor = RHSVisitor()
                call_arg_rhs_visitor.visit(call_arg)
                restore_node = RestoreNode(
                    def_arg_temp_name + ' = ' + call_arg_label_visitor.result,
                    def_arg_temp_name,
                    call_arg_rhs_visitor.result,
                    line_number=line_number,
                    path=self.filenames[-1]
                )
            # If there are no saved variables, then this is the first node
            if not first_node:
                first_node = restore_node
            if isinstance(call_arg, ast.Call):
                if last_return_value_of_nested_call:
                    # connect inner to other_inner in e.g. `outer(inner(image_name), other_inner(image_name))`
                    if isinstance(return_value_of_nested_call, BBorBInode):
                        last_return_value_of_nested_call.connect(return_value_of_nested_call)
                    else:
                        last_return_value_of_nested_call.connect(return_value_of_nested_call.first_node)
                else:
                    # I should only set this once per loop, inner in e.g. `outer(inner(image_name), other_inner(image_name))`
                    # (inner_most_call is used when predecessor is a ControlFlowNode in connect_control_flow_node)
                    if isinstance(return_value_of_nested_call, BBorBInode):
                        first_node.inner_most_call = return_value_of_nested_call
                    else:
                        first_node.inner_most_call = return_value_of_nested_call.first_node
                # We purposefully should not set this as the first_node of return_value_of_nested_call, last makes sense
                last_return_value_of_nested_call = return_value_of_nested_call
            self.connect_if_allowed(self.nodes[-1], restore_node)
            self.nodes.append(restore_node)
            if isinstance(call_arg, ast.Call):
                args_mapping[return_value_of_nested_call.left_hand_side] = def_args[i]
            else:
                args_mapping[def_args[i]] = call_arg_label_visitor.result
        return (args_mapping, first_node)
    def create_local_scope_from_def_args(
        self,
        call_args,
        def_args,
        line_number,
        saved_function_call_index
    ):
        """Create the local scope before entering the body of a function call.

        Args:
            call_args(list[ast.Name]): Of the call being made.
            def_args(ast_helper.Arguments): Of the definition being called.
            line_number(int): Of the def of the function call about to be entered into.
            saved_function_call_index(int): Unique number for each call.

        Note: We do not need a connect_if_allowed because of the
        preceding call to save_def_args_in_temp.
        """
        # Create e.g. def_arg1 = temp_N_def_arg1 for each argument
        for i in range(len(call_args)):
            def_arg_local_name = def_args[i]
            def_arg_temp_name = 'temp_' + str(saved_function_call_index) + '_' + def_args[i]
            local_scope_node = RestoreNode(
                def_arg_local_name + ' = ' + def_arg_temp_name,
                def_arg_local_name,
                [def_arg_temp_name],
                line_number=line_number,
                path=self.filenames[-1]
            )
            # Chain the local scope nodes together
            self.nodes[-1].connect(local_scope_node)
            self.nodes.append(local_scope_node)
    def visit_and_get_function_nodes(
        self,
        definition,
        first_node
    ):
        """Visits the nodes of a user defined function.

        Args:
            definition(LocalModuleDefinition): Definition of the function being added.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.

        Returns:
            the_new_nodes(list[Node]): The nodes added while visiting the function.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
        """
        len_before_visiting_func = len(self.nodes)
        previous_node = self.nodes[-1]
        entry_node = self.append_node(EntryOrExitNode('Function Entry ' +
                                                      definition.name))
        if not first_node:
            first_node = entry_node
        self.connect_if_allowed(previous_node, entry_node)
        function_body_connect_statements = self.stmt_star_handler(definition.node.body)
        entry_node.connect(function_body_connect_statements.first_statement)
        exit_node = self.append_node(EntryOrExitNode('Exit ' + definition.name))
        exit_node.connect_predecessors(function_body_connect_statements.last_statements)
        # Everything appended since len_before_visiting_func belongs to this function.
        the_new_nodes = self.nodes[len_before_visiting_func:]
        # Wire any Return/Raise nodes in the body to the exit node.
        return_connection_handler(the_new_nodes, exit_node)
        return (the_new_nodes, first_node)
    def restore_saved_local_scope(
        self,
        saved_variables,
        args_mapping,
        line_number
    ):
        """Restore the previously saved variables to their original values.

        Args:
            saved_variables(list[SavedVariable])
            args_mapping(dict): A mapping of call argument to definition argument.
            line_number(int): Of the def of the function call about to be entered into.

        Note: We do not need connect_if_allowed because of the
        preceding call to save_local_scope.
        """
        restore_nodes = list()
        for var in saved_variables:
            # Is var.RHS a call argument?
            if var.RHS in args_mapping:
                # If so, use the corresponding definition argument for the RHS of the label.
                restore_nodes.append(RestoreNode(
                    var.RHS + ' = ' + args_mapping[var.RHS],
                    var.RHS,
                    [var.LHS],
                    line_number=line_number,
                    path=self.filenames[-1]
                ))
            else:
                # Create a node for e.g. foo = save_1_foo
                restore_nodes.append(RestoreNode(
                    var.RHS + ' = ' + var.LHS,
                    var.RHS,
                    [var.LHS],
                    line_number=line_number,
                    path=self.filenames[-1]
                ))
        # Chain the restore nodes
        for node, successor in zip(restore_nodes, restore_nodes[1:]):
            node.connect(successor)
        if restore_nodes:
            # Connect the last node to the first restore node
            self.nodes[-1].connect(restore_nodes[0])
            self.nodes.extend(restore_nodes)
        return restore_nodes
    def return_handler(
        self,
        call_node,
        function_nodes,
        saved_function_call_index,
        first_node
    ):
        """Handle the return from a function during a function call.

        Args:
            call_node(ast.Call) : The node that calls the definition.
            function_nodes(list[Node]): List of nodes of the function being called.
            saved_function_call_index(int): Unique number for each call.
            first_node(EntryOrExitNode or RestoreNode): Used to connect previous statements to this function.
        """
        if any(isinstance(node, YieldNode) for node in function_nodes):
            # Presence of a `YieldNode` means that the function is a generator
            rhs_prefix = 'yld_'
        elif any(isinstance(node, ConnectToExitNode) for node in function_nodes):
            # Only `Return`s and `Raise`s can be of type ConnectToExitNode
            rhs_prefix = 'ret_'
        else:
            return  # No return value
        # Create e.g. ~call_1 = ret_func_foo RestoreNode
        LHS = CALL_IDENTIFIER + 'call_' + str(saved_function_call_index)
        RHS = rhs_prefix + get_call_names_as_string(call_node.func)
        return_node = RestoreNode(
            LHS + ' = ' + RHS,
            LHS,
            [RHS],
            line_number=call_node.lineno,
            path=self.filenames[-1]
        )
        return_node.first_node = first_node
        self.nodes[-1].connect(return_node)
        self.nodes.append(return_node)
    def process_function(self, call_node, definition):
        """Processes a user defined function when it is called.

        Increments self.function_call_index each time it is called, we can refer to it as N in the comments.
        Make e.g. save_N_LHS = assignment.LHS for each AssignmentNode. (save_local_scope)
        Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument.
        Visit the arguments if they're calls. (save_def_args_in_temp)
        Create e.g. def_arg1 = temp_N_def_arg1 for each argument. (create_local_scope_from_def_args)
        Visit and get function nodes. (visit_and_get_function_nodes)
        Loop through each save_N_LHS node and create an e.g.
        foo = save_1_foo or, if foo was a call arg, foo = arg_mapping[foo]. (restore_saved_local_scope)
        Create e.g. ~call_1 = ret_func_foo RestoreNode. (return_handler)

        Notes:
            Page 31 in the original thesis, but changed a little.
            We don't have to return the ~call_1 = ret_func_foo RestoreNode made in return_handler,
            because it's the last node anyway, that we return in this function.
            e.g. ret_func_foo gets assigned to visit_Return.

        Args:
            call_node(ast.Call) : The node that calls the definition.
            definition(LocalModuleDefinition): Definition of the function being called.

        Returns:
            Last node in self.nodes, probably the return of the function appended to self.nodes in return_handler.
        """
        self.function_call_index += 1
        saved_function_call_index = self.function_call_index
        def_node = definition.node
        saved_variables, first_node = self.save_local_scope(
            def_node.lineno,
            saved_function_call_index
        )
        args_mapping, first_node = self.save_def_args_in_temp(
            call_node.args,
            Arguments(def_node.args),
            call_node.lineno,
            saved_function_call_index,
            first_node
        )
        # Switch to the callee's file while visiting its body.
        self.filenames.append(definition.path)
        self.create_local_scope_from_def_args(
            call_node.args,
            Arguments(def_node.args),
            def_node.lineno,
            saved_function_call_index
        )
        function_nodes, first_node = self.visit_and_get_function_nodes(
            definition,
            first_node
        )
        self.filenames.pop()  # Should really probably move after restore_saved_local_scope!!!
        self.restore_saved_local_scope(
            saved_variables,
            args_mapping,
            def_node.lineno
        )
        self.return_handler(
            call_node,
            function_nodes,
            saved_function_call_index,
            first_node
        )
        self.function_return_stack.pop()
        return self.nodes[-1]
    def visit_Call(self, node):
        """Dispatch a call to a user-defined function, a builtin, or a blackbox."""
        _id = get_call_names_as_string(node.func)
        local_definitions = self.module_definitions_stack[-1]
        # Resolve "import X as Y" style aliases before looking up the definition.
        alias = handle_aliases_in_calls(_id, local_definitions.import_alias_mapping)
        if alias:
            definition = local_definitions.get_definition(alias)
        else:
            definition = local_definitions.get_definition(_id)
        # e.g. "request.args.get" -> "get"
        last_attribute = _id.rpartition('.')[-1]
        if definition:
            if isinstance(definition.node, ast.ClassDef):
                # NOTE(review): this branch does not return, so execution falls
                # through to the final add_blackbox_or_builtin_call below as
                # well — confirm against upstream whether the double add is
                # intentional.
                self.add_blackbox_or_builtin_call(node, blackbox=False)
            elif isinstance(definition.node, ast.FunctionDef):
                self.undecided = False
                self.function_return_stack.append(_id)
                return self.process_function(node, definition)
            else:
                raise Exception('Definition was neither FunctionDef or ' +
                                'ClassDef, cannot add the function ')
        elif last_attribute not in BUILTINS:
            # Mark the call as a blackbox because we don't have the definition
            return self.add_blackbox_or_builtin_call(node, blackbox=True)
        return self.add_blackbox_or_builtin_call(node, blackbox=False)
from .expr_visitor import ExprVisitor
class CFG():
    """A control flow graph: an ordered node list plus blackbox assignments."""

    def __init__(
        self,
        nodes,
        blackbox_assignments,
        filename
    ):
        self.nodes = nodes
        self.blackbox_assignments = blackbox_assignments
        self.filename = filename

    def __repr__(self):
        # One "Node: <index> <repr>" entry per node, blank-line separated.
        return ''.join(
            'Node: ' + str(index) + ' ' + repr(node) + '\n\n'
            for index, node in enumerate(self.nodes)
        )

    def __str__(self):
        # Same layout as __repr__, but with str() of each node.
        return ''.join(
            'Node: ' + str(index) + ' ' + str(node) + '\n\n'
            for index, node in enumerate(self.nodes)
        )
def make_cfg(
    tree,
    project_modules,
    local_modules,
    filename,
    module_definitions=None,
    allow_local_directory_imports=True
):
    """Build a CFG for `tree` by running an ExprVisitor over it.

    Args:
        tree: Parsed AST of the module (or function) to analyse.
        project_modules: Modules of the project being analysed.
        local_modules: Modules in the local directory.
        filename(str): Path of the file being analysed.
        module_definitions: When given, build a function CFG instead of a
            module CFG (see ExprVisitor.__init__).
        allow_local_directory_imports(bool): Whether local modules may be followed.

    Returns:
        CFG wrapping the visitor's nodes and blackbox assignments.
    """
    visitor = ExprVisitor(
        tree,
        project_modules,
        local_modules,
        filename,
        module_definitions,
        allow_local_directory_imports
    )
    return CFG(
        visitor.nodes,
        visitor.blackbox_assignments,
        filename
    )
import ast
import random
from collections import namedtuple
from ..core.node_types import (
AssignmentCallNode,
BBorBInode,
BreakNode,
ControlFlowNode,
RestoreNode
)
# Prefix marking labels of CFG nodes generated for function calls.
CALL_IDENTIFIER = '~'

# Bundle returned by stmt_star_handler: the entry statement, the exit
# statements, and any break statements collected along the way.
ConnectStatements = namedtuple(
    'ConnectStatements',
    'first_statement last_statements break_statements'
)
def _get_inner_most_function_call(call_node):
    """Follow inner_most_call/first_node links down to the innermost call node."""
    # Loop to inner most function call
    # e.g. return scrypt.inner in `foo = scrypt.outer(scrypt.inner(image_name))`
    old_call_node = None
    while call_node != old_call_node:
        old_call_node = call_node
        if isinstance(call_node, BBorBInode):
            call_node = call_node.inner_most_call
        else:
            try:
                # e.g. save_2_blah, even when there is a save_3_blah
                call_node = call_node.first_node
            except AttributeError:
                # No inner calls
                # Possible improvement: Make new node for RestoreNode's made in process_function
                # and make `self.inner_most_call = self`
                # So that we can duck type and not catch an exception when there are no inner calls.
                # This is what we do in BBorBInode
                pass
    return call_node
def _connect_control_flow_node(control_flow_node, next_node):
    """Connect a ControlFlowNode properly to the next_node."""
    # Every exit of the control-flow construct must reach the next statement.
    for last in control_flow_node.last_nodes:
        if isinstance(next_node, ControlFlowNode):
            last.connect(next_node.test)  # connect to next if test case
        elif isinstance(next_node, AssignmentCallNode):
            # Enter through the innermost nested call so execution order is right.
            call_node = next_node.call_node
            inner_most_call_node = _get_inner_most_function_call(call_node)
            last.connect(inner_most_call_node)
        else:
            last.connect(next_node)
def connect_nodes(nodes):
    """Connect the nodes in a list linearly."""
    for n, next_node in zip(nodes, nodes[1:]):
        if isinstance(n, ControlFlowNode):
            _connect_control_flow_node(n, next_node)
        elif isinstance(next_node, ControlFlowNode):
            n.connect(next_node.test)
        elif isinstance(next_node, RestoreNode):
            # RestoreNodes are wired up when they are created; skip them here.
            continue
        elif CALL_IDENTIFIER in next_node.label:
            # Call nodes are wired up by the call-handling machinery.
            continue
        else:
            n.connect(next_node)
def _get_names(node, result):
"""Recursively finds all names."""
if isinstance(node, ast.Name):
return node.id + result
elif isinstance(node, ast.Subscript):
return result
elif isinstance(node, ast.Starred):
return _get_names(node.value, result)
else:
return _get_names(node.value, result + '.' + node.attr)
def extract_left_hand_side(target):
"""Extract the left hand side variable from a target.
Removes list indexes, stars and other left hand side elements.
"""
left_hand_side = _get_names(target, '')
left_hand_side.replace('*', '')
if '[' in left_hand_side:
index = left_hand_side.index('[')
left_hand_side = target[:index]
return left_hand_side
def get_first_node(
    node,
    node_not_to_step_past
):
    """
    This is a super hacky way of getting the first node after a statement.
    We do this because we visit a statement and keep on visiting and get something in return that is rarely the first node.
    So we loop and loop backwards until we hit the statement or there is nothing to step back to.
    """
    ingoing = None
    i = 0
    current_node = node
    while current_node.ingoing:
        # This is used because there may be multiple ingoing and loop will cause an infinite loop if we did [0]
        # NOTE(review): random choice makes the backward walk nondeterministic
        # between runs.
        i = random.randrange(len(current_node.ingoing))
        # e.g. We don't want to step past the Except of an Except basic block
        if current_node.ingoing[i] == node_not_to_step_past:
            break
        ingoing = current_node.ingoing
        current_node = current_node.ingoing[i]
    if ingoing:
        # Return the last node we stepped through, one step before current_node.
        return ingoing[i]
    return current_node
def get_first_statement(node_or_tuple):
    """Find the first statement of the provided object.

    Returns:
        The first element in the tuple if it is a tuple.
        The node if it is a node.
    """
    if isinstance(node_or_tuple, tuple):
        return node_or_tuple[0]
    return node_or_tuple
def get_last_statements(cfg_statements):
    """Retrieve the last statements from a cfg_statements list."""
    tail = cfg_statements[-1]
    # A control-flow construct can have several exit nodes.
    if isinstance(tail, ControlFlowNode):
        return tail.last_nodes
    return [tail]
def remove_breaks(last_statements):
    """Remove all break statements in last_statements."""
    kept = []
    for statement in last_statements:
        if not isinstance(statement, BreakNode):
            kept.append(statement)
    return kept
import ast
import itertools
import os.path
from .alias_helper import (
as_alias_handler,
handle_aliases_in_init_files,
handle_fdid_aliases,
not_as_alias_handler,
retrieve_import_alias_mapping
)
from ..core.ast_helper import (
generate_ast,
get_call_names_as_string
)
from ..core.module_definitions import (
LocalModuleDefinition,
ModuleDefinition,
ModuleDefinitions
)
from ..core.node_types import (
AssignmentNode,
AssignmentCallNode,
BBorBInode,
BreakNode,
ControlFlowNode,
EntryOrExitNode,
IfNode,
IgnoredNode,
Node,
RaiseNode,
ReturnNode,
TryNode
)
from ..core.project_handler import (
get_directory_modules
)
from ..helper_visitors import (
LabelVisitor,
RHSVisitor,
VarsVisitor
)
from .stmt_visitor_helper import (
CALL_IDENTIFIER,
ConnectStatements,
connect_nodes,
extract_left_hand_side,
get_first_node,
get_first_statement,
get_last_statements,
remove_breaks
)
class StmtVisitor(ast.NodeVisitor):
    """CFG builder for statement-level AST nodes (class continues past this chunk)."""

    def __init__(self, allow_local_directory_imports=True):
        # When False, modules found in the local directory are not followed.
        self._allow_local_modules = allow_local_directory_imports
        super().__init__()

    def visit_Module(self, node):
        return self.stmt_star_handler(node.body)
    def stmt_star_handler(
        self,
        stmts,
        prev_node_to_avoid=None
    ):
        """Handle stmt* expressions in an AST node.

        Links all statements together in a list of statements, accounting for statements with multiple last nodes.

        Returns:
            ConnectStatements, or IgnoredNode when every statement was ignored.
        """
        break_nodes = list()
        cfg_statements = list()
        self.prev_nodes_to_avoid.append(prev_node_to_avoid)
        self.last_control_flow_nodes.append(None)
        first_node = None
        node_not_to_step_past = self.nodes[-1]
        for stmt in stmts:
            node = self.visit(stmt)
            # Remember a pending control-flow test (except try blocks) so the
            # next statement can be connected from it (see connect_if_allowed).
            if isinstance(node, ControlFlowNode) and not isinstance(node.test, TryNode):
                self.last_control_flow_nodes.append(node.test)
            else:
                self.last_control_flow_nodes.append(None)
            if isinstance(node, ControlFlowNode):
                break_nodes.extend(node.break_statements)
            elif isinstance(node, BreakNode):
                break_nodes.append(node)
            if not isinstance(node, IgnoredNode):
                cfg_statements.append(node)
                if not first_node:
                    if isinstance(node, ControlFlowNode):
                        first_node = node.test
                    else:
                        first_node = get_first_node(
                            node,
                            node_not_to_step_past
                        )
        self.prev_nodes_to_avoid.pop()
        self.last_control_flow_nodes.pop()
        connect_nodes(cfg_statements)
        if cfg_statements:
            if first_node:
                first_statement = first_node
            else:
                first_statement = get_first_statement(cfg_statements[0])
            last_statements = get_last_statements(cfg_statements)
            return ConnectStatements(
                first_statement=first_statement,
                last_statements=last_statements,
                break_statements=break_nodes
            )
        else:  # When body of module only contains ignored nodes
            return IgnoredNode()
def get_parent_definitions(self):
parent_definitions = None
if len(self.module_definitions_stack) > 1:
parent_definitions = self.module_definitions_stack[-2]
return parent_definitions
def add_to_definitions(self, node):
    """Record a function/class definition in the current module scope.

    When a parent module exists, the definition is also recorded there,
    qualified with the class names currently on the parent's class stack.
    """
    local_definitions = self.module_definitions_stack[-1]
    parent_definitions = self.get_parent_definitions()

    if parent_definitions:
        # Qualified name as seen from the parent module.
        parent_qualified_name = '.'.join(
            parent_definitions.classes +
            [node.name]
        )
        parent_definition = ModuleDefinition(
            parent_definitions,
            parent_qualified_name,
            local_definitions.module_name,
            self.filenames[-1]
        )
        parent_definition.node = node
        parent_definitions.append_if_local_or_in_imports(parent_definition)

    # Qualified name within the current module's class nesting.
    local_qualified_name = '.'.join(local_definitions.classes +
                                    [node.name])
    local_definition = LocalModuleDefinition(
        local_definitions,
        local_qualified_name,
        None,
        self.filenames[-1]
    )
    local_definition.node = node
    local_definitions.append_if_local_or_in_imports(local_definition)

    self.function_names.append(node.name)
def visit_ClassDef(self, node):
    """Record a class definition and walk its body.

    The class name is pushed on the class stack(s) while the body is
    visited — so nested definitions get qualified names — then popped.
    """
    self.add_to_definitions(node)

    local_definitions = self.module_definitions_stack[-1]
    local_definitions.classes.append(node.name)

    parent_definitions = self.get_parent_definitions()
    if parent_definitions:
        parent_definitions.classes.append(node.name)

    self.stmt_star_handler(node.body)

    local_definitions.classes.pop()
    if parent_definitions:
        parent_definitions.classes.pop()

    return IgnoredNode()
def visit_FunctionDef(self, node):
    """Record a function definition; its body is not walked here."""
    self.add_to_definitions(node)
    return IgnoredNode()
def handle_or_else(self, orelse, test):
    """Handle the orelse part of an if or try node.

    Args:
        orelse(list[Node])
        test(Node)

    Returns:
        The last nodes of the orelse branch.
    """
    if isinstance(orelse[0], ast.If):
        # An `elif` arrives as a nested ast.If in orelse.
        control_flow_node = self.visit(orelse[0])
        # Prefix the if label with 'el'
        control_flow_node.test.label = 'el' + control_flow_node.test.label

        test.connect(control_flow_node.test)
        return control_flow_node.last_nodes
    else:
        else_connect_statements = self.stmt_star_handler(
            orelse,
            prev_node_to_avoid=self.nodes[-1]
        )
        test.connect(else_connect_statements.first_statement)
        return else_connect_statements.last_statements
def visit_If(self, node):
    """Build the CFG for an if/elif/else statement."""
    test = self.append_node(IfNode(
        node.test,
        node,
        path=self.filenames[-1]
    ))

    body_connect_stmts = self.stmt_star_handler(node.body)
    if isinstance(body_connect_stmts, IgnoredNode):
        # An all-ignored body still needs a ConnectStatements so the
        # orelse handling below can extend last_statements.
        body_connect_stmts = ConnectStatements(
            first_statement=test,
            last_statements=[],
            break_statements=[]
        )
    test.connect(body_connect_stmts.first_statement)

    if node.orelse:
        orelse_last_nodes = self.handle_or_else(node.orelse, test)
        body_connect_stmts.last_statements.extend(orelse_last_nodes)
    else:
        body_connect_stmts.last_statements.append(test)  # if there is no orelse, test needs an edge to the next_node

    last_statements = remove_breaks(body_connect_stmts.last_statements)
    return ControlFlowNode(test, last_statements, break_statements=body_connect_stmts.break_statements)
def visit_Raise(self, node):
    """Create a RaiseNode for a `raise` statement."""
    return self.append_node(RaiseNode(
        node,
        path=self.filenames[-1]
    ))
def visit_Return(self, node):
    """Create a ReturnNode assigning the value to 'ret_<function name>'."""
    label = LabelVisitor()
    label.visit(node)

    this_function_name = self.function_return_stack[-1]
    LHS = 'ret_' + this_function_name

    if isinstance(node.value, ast.Call):
        # `return foo()`: visit the call first and assign its result.
        return_value_of_call = self.visit(node.value)
        return_node = ReturnNode(
            LHS + ' = ' + return_value_of_call.left_hand_side,
            LHS,
            node,
            [return_value_of_call.left_hand_side],
            path=self.filenames[-1]
        )
        return_value_of_call.connect(return_node)
        return self.append_node(return_node)
    elif node.value is not None:
        rhs_visitor_result = RHSVisitor.result_for_node(node.value)
    else:
        # Bare `return` reads no variables.
        rhs_visitor_result = []

    return self.append_node(ReturnNode(
        LHS + ' = ' + label.result,
        LHS,
        node,
        rhs_visitor_result,
        path=self.filenames[-1]
    ))
def handle_stmt_star_ignore_node(self, body, fallback_cfg_node):
    """Connect *fallback_cfg_node* to *body*, or substitute for an empty body.

    An IgnoredNode body has no `.first_statement`; the AttributeError
    branch then wraps the fallback node in a ConnectStatements instead.
    """
    try:
        fallback_cfg_node.connect(body.first_statement)
    except AttributeError:
        # NOTE(review): first_statement is a list here, while
        # stmt_star_handler produces a single node — confirm callers cope.
        body = ConnectStatements(
            first_statement=[fallback_cfg_node],
            last_statements=[fallback_cfg_node],
            break_statements=[]
        )
    return body
def visit_Try(self, node):
    """Build CFG nodes for a try/except/else/finally statement."""
    try_node = self.append_node(TryNode(
        node,
        path=self.filenames[-1]
    ))
    body = self.stmt_star_handler(node.body)
    body = self.handle_stmt_star_ignore_node(body, try_node)

    last_statements = list()
    for handler in node.handlers:
        try:
            name = handler.type.id
        except AttributeError:
            # Bare `except:` or a non-Name exception expression.
            name = ''
        handler_node = self.append_node(Node(
            'except ' + name + ':',
            handler,
            line_number=handler.lineno,
            path=self.filenames[-1]
        ))
        # Every last statement of the try body may jump to every handler.
        for body_node in body.last_statements:
            body_node.connect(handler_node)
        handler_body = self.stmt_star_handler(handler.body)
        handler_body = self.handle_stmt_star_ignore_node(handler_body, handler_node)
        last_statements.extend(handler_body.last_statements)

    if node.orelse:
        orelse_last_nodes = self.handle_or_else(node.orelse, body.last_statements[-1])
        body.last_statements.extend(orelse_last_nodes)

    if node.finalbody:
        finalbody = self.stmt_star_handler(node.finalbody)
        # Both the handlers and the body/orelse flow into the finally block.
        for last in last_statements:
            last.connect(finalbody.first_statement)

        for last in body.last_statements:
            last.connect(finalbody.first_statement)

        body.last_statements.extend(finalbody.last_statements)

    last_statements.extend(remove_breaks(body.last_statements))

    return ControlFlowNode(try_node, last_statements, break_statements=body.break_statements)
def assign_tuple_target(self, node, right_hand_side_variables):
    """Handle e.g. `a, b = 1, 2`, including starred targets like `a, *b = ...`."""
    new_assignment_nodes = []
    remaining_variables = list(right_hand_side_variables)
    remaining_targets = list(node.targets[0].elts)
    remaining_values = list(node.value.elts)  # May contain duplicates

    def visit(target, value):
        # Create one assignment node for a single target/value pair and
        # drop the pair from the 'remaining' bookkeeping lists.
        label = LabelVisitor()
        label.visit(target)
        rhs_visitor = RHSVisitor()
        rhs_visitor.visit(value)
        if isinstance(value, ast.Call):
            new_ast_node = ast.Assign(target, value)
            ast.copy_location(new_ast_node, node)
            new_assignment_nodes.append(self.assignment_call_node(label.result, new_ast_node))
        else:
            label.result += ' = '
            label.visit(value)
            new_assignment_nodes.append(self.append_node(AssignmentNode(
                label.result,
                extract_left_hand_side(target),
                ast.Assign(target, value),
                rhs_visitor.result,
                line_number=node.lineno,
                path=self.filenames[-1]
            )))
        remaining_targets.remove(target)
        remaining_values.remove(value)
        for var in rhs_visitor.result:
            remaining_variables.remove(var)

    # Pair targets and values until a Starred node is reached
    for target, value in zip(node.targets[0].elts, node.value.elts):
        if isinstance(target, ast.Starred) or isinstance(value, ast.Starred):
            break
        visit(target, value)

    # If there was a Starred node, pair remaining targets and values from the end
    for target, value in zip(reversed(list(remaining_targets)), reversed(list(remaining_values))):
        if isinstance(target, ast.Starred) or isinstance(value, ast.Starred):
            break
        visit(target, value)

    if remaining_targets:
        # What is left belongs to the starred target: assign all remaining
        # values (and variables) to each remaining target.
        label = LabelVisitor()
        label.handle_comma_separated(remaining_targets)
        label.result += ' = '
        label.handle_comma_separated(remaining_values)
        for target in remaining_targets:
            new_assignment_nodes.append(self.append_node(AssignmentNode(
                label.result,
                extract_left_hand_side(target),
                ast.Assign(target, remaining_values[0]),
                remaining_variables,
                line_number=node.lineno,
                path=self.filenames[-1]
            )))

    connect_nodes(new_assignment_nodes)
    return ControlFlowNode(new_assignment_nodes[0], [new_assignment_nodes[-1]], [])  # return the last added node
def assign_multi_target(self, node, right_hand_side_variables):
    """Handle e.g. `x = y = 3`: one assignment node per target."""
    new_assignment_nodes = list()

    for target in node.targets:
        label = LabelVisitor()
        label.visit(target)
        left_hand_side = label.result
        label.result += ' = '
        label.visit(node.value)

        new_assignment_nodes.append(self.append_node(AssignmentNode(
            label.result,
            left_hand_side,
            ast.Assign(target, node.value),
            right_hand_side_variables,
            line_number=node.lineno,
            path=self.filenames[-1]
        )))

    connect_nodes(new_assignment_nodes)
    return ControlFlowNode(new_assignment_nodes[0], [new_assignment_nodes[-1]], [])  # return the last added node
def visit_Assign(self, node):
    """Dispatch the different assignment forms to their handlers."""
    rhs_visitor = RHSVisitor()
    rhs_visitor.visit(node.value)
    if isinstance(node.targets[0], (ast.Tuple, ast.List)):  # x,y = [1,2]
        if isinstance(node.value, (ast.Tuple, ast.List)):
            return self.assign_tuple_target(node, rhs_visitor.result)
        elif isinstance(node.value, ast.Call):
            # e.g. x, y = foo()
            # NOTE(review): only the node for the last target element is
            # returned here — confirm that is intended.
            call = None
            for element in node.targets[0].elts:
                label = LabelVisitor()
                label.visit(element)
                call = self.assignment_call_node(label.result, node)
            return call
        else:
            # Unsupported tuple-assignment form; fall back to one node.
            label = LabelVisitor()
            label.visit(node)
            print('Assignment not properly handled.',
                  'Could result in not finding a vulnerability.',
                  'Assignment:', label.result)
            return self.append_node(AssignmentNode(
                label.result,
                label.result,
                node,
                rhs_visitor.result,
                path=self.filenames[-1]
            ))
    elif len(node.targets) > 1:  # x = y = 3
        return self.assign_multi_target(node, rhs_visitor.result)
    else:
        if isinstance(node.value, ast.Call):  # x = call()
            label = LabelVisitor()
            label.visit(node.targets[0])
            return self.assignment_call_node(label.result, node)
        else:  # x = 4
            label = LabelVisitor()
            label.visit(node)
            return self.append_node(AssignmentNode(
                label.result,
                extract_left_hand_side(node.targets[0]),
                node,
                rhs_visitor.result,
                path=self.filenames[-1]
            ))
def visit_AnnAssign(self, node):
    """Treat an annotated assignment like a plain one.

    A bare annotation without a value (e.g. `x: int`) produces no CFG node.
    """
    if node.value is None:
        return IgnoredNode()
    else:
        assign = ast.Assign(targets=[node.target], value=node.value)
        ast.copy_location(assign, node)
        return self.visit(assign)
def assignment_call_node(self, left_hand_label, ast_node):
    """Handle assignments that contain a function call on its right side."""
    self.undecided = True  # Used for handling functions in assignments
    call = self.visit(ast_node.value)
    call_label = call.left_hand_side

    if isinstance(call, BBorBInode):
        # Necessary to know e.g.
        # `image_name = image_name.replace('..', '')`
        # is a reassignment.
        vars_visitor = VarsVisitor()
        vars_visitor.visit(ast_node.value)
        call.right_hand_side_variables.extend(vars_visitor.result)

    call_assignment = AssignmentCallNode(
        left_hand_label + ' = ' + call_label,
        left_hand_label,
        ast_node,
        [call.left_hand_side],
        line_number=ast_node.lineno,
        path=self.filenames[-1],
        call_node=call
    )
    call.connect(call_assignment)

    self.nodes.append(call_assignment)
    self.undecided = False

    return call_assignment
def visit_AugAssign(self, node):
    """Handle e.g. `x += y`.

    The target is both read and written, so it is included in the
    right-hand-side variables as well.
    """
    label = LabelVisitor()
    label.visit(node)

    rhs_visitor = RHSVisitor()
    rhs_visitor.visit(node.value)

    lhs = extract_left_hand_side(node.target)
    return self.append_node(AssignmentNode(
        label.result,
        lhs,
        node,
        rhs_visitor.result + [lhs],
        path=self.filenames[-1]
    ))
def loop_node_skeleton(self, test, node):
    """Common handling of looped structures, while and for.

    Args:
        test(Node): The already-appended loop test node.
        node: The ast.While/ast.For node whose body and orelse are visited.

    Returns:
        ControlFlowNode wrapping the test and the loop's last nodes.
    """
    body_connect_stmts = self.stmt_star_handler(
        node.body,
        prev_node_to_avoid=self.nodes[-1]
    )

    test.connect(body_connect_stmts.first_statement)
    # Back-edge: the end of the body flows back to the test.
    test.connect_predecessors(body_connect_stmts.last_statements)

    # last_nodes is used for making connections to the next node in the parent node
    # this is handled in stmt_star_handler
    last_nodes = list()
    last_nodes.extend(body_connect_stmts.break_statements)

    if node.orelse:
        orelse_connect_stmts = self.stmt_star_handler(
            node.orelse,
            prev_node_to_avoid=self.nodes[-1]
        )

        test.connect(orelse_connect_stmts.first_statement)
        last_nodes.extend(orelse_connect_stmts.last_statements)
    else:
        last_nodes.append(test)  # if there is no orelse, test needs an edge to the next_node

    return ControlFlowNode(test, last_nodes, list())
def visit_For(self, node):
    """Create the 'for x in y:' test node and build the loop skeleton."""
    self.undecided = False

    iterator_label = LabelVisitor()
    iterator_label.visit(node.iter)
    target_label = LabelVisitor()
    target_label.visit(node.target)

    for_node = self.append_node(Node(
        "for " + target_label.result + " in " + iterator_label.result + ':',
        node,
        path=self.filenames[-1]
    ))

    if isinstance(node.iter, ast.Call) and get_call_names_as_string(node.iter.func) in self.function_names:
        # e.g. `for x in foo():` where foo is user-defined — visit the
        # call so its CFG precedes the loop test.
        last_node = self.visit(node.iter)
        last_node.connect(for_node)

    return self.loop_node_skeleton(for_node, node)
def visit_While(self, node):
    """Create the 'while <test>:' node and build the loop skeleton."""
    label_visitor = LabelVisitor()
    label_visitor.visit(node.test)

    test = self.append_node(Node(
        'while ' + label_visitor.result + ':',
        node,
        path=self.filenames[-1]
    ))

    return self.loop_node_skeleton(test, node)
def add_blackbox_or_builtin_call(self, node, blackbox):
    """Processes a blackbox or builtin function when it is called.

    Nothing gets assigned to ret_func_foo in the builtin/blackbox case.
    Increments self.function_call_index each time it is called, we can refer to it as N in the comments.
    Create e.g. ~call_1 = ret_func_foo RestoreNode.

    Args:
        node(ast.Call) : The node that calls the definition.
        blackbox(bool): Whether or not it is a builtin or blackbox call.
    Returns:
        call_node(BBorBInode): The call node.
    """
    self.function_call_index += 1
    saved_function_call_index = self.function_call_index
    self.undecided = False

    call_label = LabelVisitor()
    call_label.visit(node)
    # Everything up to the first '(' is the function name.
    index = call_label.result.find('(')

    # Create e.g. ~call_1 = ret_func_foo
    LHS = CALL_IDENTIFIER + 'call_' + str(saved_function_call_index)
    RHS = 'ret_' + call_label.result[:index] + '('

    call_node = BBorBInode(
        label='',
        left_hand_side=LHS,
        ast_node=node,
        right_hand_side_variables=[],
        line_number=node.lineno,
        path=self.filenames[-1],
        func_name=call_label.result[:index]
    )
    visual_args = list()
    rhs_vars = list()
    last_return_value_of_nested_call = None

    for arg in itertools.chain(node.args, node.keywords):
        if isinstance(arg, ast.Call):
            return_value_of_nested_call = self.visit(arg)

            if last_return_value_of_nested_call:
                # connect inner to other_inner in e.g.
                # `scrypt.outer(scrypt.inner(image_name), scrypt.other_inner(image_name))`
                # I should probably loop to the inner most call of other_inner here.
                try:
                    last_return_value_of_nested_call.connect(return_value_of_nested_call.first_node)
                except AttributeError:
                    last_return_value_of_nested_call.connect(return_value_of_nested_call)
            else:
                # I should only set this once per loop, inner in e.g.
                # `scrypt.outer(scrypt.inner(image_name), scrypt.other_inner(image_name))`
                # (inner_most_call is used when predecessor is a ControlFlowNode in connect_control_flow_node)
                call_node.inner_most_call = return_value_of_nested_call
            last_return_value_of_nested_call = return_value_of_nested_call

            visual_args.append(return_value_of_nested_call.left_hand_side)
            rhs_vars.append(return_value_of_nested_call.left_hand_side)
        else:
            label = LabelVisitor()
            label.visit(arg)
            visual_args.append(label.result)

            vv = VarsVisitor()
            vv.visit(arg)
            rhs_vars.extend(vv.result)

    if last_return_value_of_nested_call:
        # connect other_inner to outer in e.g.
        # `scrypt.outer(scrypt.inner(image_name), scrypt.other_inner(image_name))`
        last_return_value_of_nested_call.connect(call_node)

    # Build e.g. 'ret_foo(a, b)'; joining is equivalent to the old
    # append-', '-then-trim loop but cannot mangle the final argument.
    if visual_args:
        RHS = RHS + ", ".join(visual_args) + ')'
    else:
        RHS = RHS + ')'

    call_node.label = LHS + " = " + RHS
    call_node.right_hand_side_variables = rhs_vars
    # Used in get_sink_args, not using right_hand_side_variables because it is extended in assignment_call_node
    rhs_visitor = RHSVisitor()
    rhs_visitor.visit(node)
    call_node.args = rhs_visitor.result

    if blackbox:
        self.blackbox_assignments.add(call_node)

    self.connect_if_allowed(self.nodes[-1], call_node)
    self.nodes.append(call_node)

    return call_node
def visit_With(self, node):
    """The first with-item becomes a node preceding the body's CFG."""
    label_visitor = LabelVisitor()
    label_visitor.visit(node.items[0])

    with_node = self.append_node(Node(
        label_visitor.result,
        node,
        path=self.filenames[-1]
    ))
    connect_statements = self.stmt_star_handler(node.body)
    with_node.connect(connect_statements.first_statement)
    return ControlFlowNode(
        with_node,
        connect_statements.last_statements,
        connect_statements.break_statements
    )
def visit_Break(self, node):
    """Create a BreakNode for a `break` statement."""
    return self.append_node(BreakNode(
        node,
        path=self.filenames[-1]
    ))
def visit_Delete(self, node):
    """Create a single node labelled with all deleted targets."""
    label_visitor = LabelVisitor()
    for target in node.targets:
        label_visitor.visit(target)

    return self.append_node(Node(
        'del ' + label_visitor.result,
        node,
        path=self.filenames[-1]
    ))
def visit_Assert(self, node):
    """Create a plain node labelled with the asserted expression."""
    label_visitor = LabelVisitor()
    label_visitor.visit(node.test)

    return self.append_node(Node(
        label_visitor.result,
        node,
        path=self.filenames[-1]
    ))
def visit_Continue(self, node):
    """A `continue` becomes a plain node labelled 'continue'."""
    return self.visit_miscelleaneous_node(
        node,
        custom_label='continue'
    )

def visit_Global(self, node):
    """A `global` statement becomes a plain node with its source as label."""
    return self.visit_miscelleaneous_node(
        node
    )

def visit_Pass(self, node):
    """A `pass` becomes a plain node labelled 'pass'."""
    return self.visit_miscelleaneous_node(
        node,
        custom_label='pass'
    )

def visit_miscelleaneous_node(
    self,
    node,
    custom_label=None
):
    """Append a plain Node for statements needing no special handling.

    Args:
        node: The AST node being visited.
        custom_label(Optional[str]): Label to use instead of the node's
            source text.
    """
    if custom_label:
        label = custom_label
    else:
        label_visitor = LabelVisitor()
        label_visitor.visit(node)
        label = label_visitor.result

    return self.append_node(Node(
        label,
        node,
        path=self.filenames[-1]
    ))
def visit_Expr(self, node):
    """An expression statement is represented by its value's CFG node(s)."""
    return self.visit(node.value)

def append_node(self, node):
    """Append a node to the CFG and return it."""
    self.nodes.append(node)
    return node
def add_module(  # noqa: C901
    self,
    module,
    module_or_package_name,
    local_names,
    import_alias_mapping,
    is_init=False,
    from_from=False,
    from_fdid=False
):
    """Analyse an imported module and splice its CFG into the current one.

    Args:
        module: (name, path) tuple identifying the module file.
        module_or_package_name: Name (str or ast.alias) the module is
            imported as, or None for `from X import Y` style imports.
        local_names: Names imported into the importing module.
        import_alias_mapping: Mapping of aliased import names.
        is_init(bool): Whether the analysed file is an __init__.py.
        from_from(bool): Whether this came from a `from X import Y`.
        from_fdid(bool): Whether this came from a from-directory import.

    Returns:
        The ExitNode that gets attached to the CFG of the class.
    """
    module_path = module[1]

    parent_definitions = self.module_definitions_stack[-1]
    # The only place the import_alias_mapping is updated
    parent_definitions.import_alias_mapping.update(import_alias_mapping)
    parent_definitions.import_names = local_names

    new_module_definitions = ModuleDefinitions(local_names, module_or_package_name)
    new_module_definitions.is_init = is_init
    self.module_definitions_stack.append(new_module_definitions)

    # Analyse the file
    self.filenames.append(module_path)
    self.local_modules = get_directory_modules(module_path) if self._allow_local_modules else []
    tree = generate_ast(module_path)

    # module[0] is None during e.g. "from . import foo", so we must str()
    self.nodes.append(EntryOrExitNode('Module Entry ' + str(module[0])))
    self.visit(tree)
    exit_node = self.append_node(EntryOrExitNode('Module Exit ' + str(module[0])))

    # Done analysing, pop the module off
    self.module_definitions_stack.pop()
    self.filenames.pop()

    if new_module_definitions.is_init:
        # Re-export the __init__.py definitions into the parent module,
        # applying any aliases from either side.
        for def_ in new_module_definitions.definitions:
            module_def_alias = handle_aliases_in_init_files(
                def_.name,
                new_module_definitions.import_alias_mapping
            )
            parent_def_alias = handle_aliases_in_init_files(
                def_.name,
                parent_definitions.import_alias_mapping
            )
            # They should never both be set
            assert not (module_def_alias and parent_def_alias)
            def_name = def_.name
            if parent_def_alias:
                def_name = parent_def_alias
            if module_def_alias:
                def_name = module_def_alias

            local_definitions = self.module_definitions_stack[-1]
            if local_definitions != parent_definitions:
                # NOTE(review): bare `raise` with no active exception raises
                # RuntimeError; this acts as an internal sanity check.
                raise
            if not isinstance(module_or_package_name, str):
                module_or_package_name = module_or_package_name.name

            if module_or_package_name:
                if from_from:
                    qualified_name = def_name

                    if from_fdid:
                        alias = handle_fdid_aliases(module_or_package_name, import_alias_mapping)
                        if alias:
                            module_or_package_name = alias
                        parent_definition = ModuleDefinition(
                            parent_definitions,
                            qualified_name,
                            module_or_package_name,
                            self.filenames[-1]
                        )
                    else:
                        parent_definition = ModuleDefinition(
                            parent_definitions,
                            qualified_name,
                            None,
                            self.filenames[-1]
                        )
                else:
                    qualified_name = module_or_package_name + '.' + def_name
                    parent_definition = ModuleDefinition(
                        parent_definitions,
                        qualified_name,
                        parent_definitions.module_name,
                        self.filenames[-1]
                    )
                parent_definition.node = def_.node
                parent_definitions.definitions.append(parent_definition)
            else:
                parent_definition = ModuleDefinition(
                    parent_definitions,
                    def_name,
                    parent_definitions.module_name,
                    self.filenames[-1]
                )
                parent_definition.node = def_.node
                parent_definitions.definitions.append(parent_definition)

    return exit_node
def from_directory_import(
    self,
    module,
    real_names,
    local_names,
    import_alias_mapping,
    skip_init=False
):
    """Handle `from <directory> import X` style imports.

    Directories don't need to be packages: with an __init__.py the whole
    package is analysed; without one, each imported name is resolved to a
    file or sub-directory inside the directory.
    """
    module_path = module[1]
    init_file_location = os.path.join(module_path, '__init__.py')
    init_exists = os.path.isfile(init_file_location)

    if init_exists and not skip_init:
        package_name = os.path.split(module_path)[1]
        return self.add_module(
            (module[0], init_file_location),
            package_name,
            local_names,
            import_alias_mapping,
            is_init=True,
            from_from=True
        )
    for real_name in real_names:
        full_name = os.path.join(module_path, real_name)
        if os.path.isdir(full_name):
            # Imported name is itself a directory: it must be a package.
            new_init_file_location = os.path.join(full_name, '__init__.py')
            if os.path.isfile(new_init_file_location):
                self.add_module(
                    (real_name, new_init_file_location),
                    real_name,
                    local_names,
                    import_alias_mapping,
                    is_init=True,
                    from_from=True,
                    from_fdid=True
                )
            else:
                raise Exception('from anything import directory needs an __init__.py file in directory')
        else:
            file_module = (real_name, full_name + '.py')
            self.add_module(
                file_module,
                real_name,
                local_names,
                import_alias_mapping,
                from_from=True
            )
    return IgnoredNode()
def import_package(self, module, module_name, local_name, import_alias_mapping):
    """Analyse a package via its __init__.py file."""
    module_path = module[1]
    init_file_location = os.path.join(module_path, '__init__.py')
    init_exists = os.path.isfile(init_file_location)
    if init_exists:
        return self.add_module(
            (module[0], init_file_location),
            module_name,
            local_name,
            import_alias_mapping,
            is_init=True
        )
    else:
        raise Exception('import directory needs an __init__.py file')
def handle_relative_import(self, node):
    """Resolve a relative import to a file or directory and analyse it.

    from A means node.level == 0
    from . import B means node.level == 1
    from .A means node.level == 1
    """
    # Directory of the currently analysed file.
    no_file = os.path.abspath(os.path.join(self.filenames[-1], os.pardir))
    skip_init = False

    if node.level == 1:
        # Same directory as current file
        if node.module:
            name_with_dir = os.path.join(no_file, node.module.replace('.', '/'))
            if not os.path.isdir(name_with_dir):
                name_with_dir = name_with_dir + '.py'
        # e.g. from . import X
        else:
            name_with_dir = no_file
            # We do not want to analyse the init file of the current directory
            skip_init = True
    else:
        parent = os.path.abspath(os.path.join(no_file, os.pardir))
        if node.level > 2:
            # Perform extra `cd ..` however many times
            for _ in range(0, node.level - 2):
                parent = os.path.abspath(os.path.join(parent, os.pardir))
        if node.module:
            name_with_dir = os.path.join(parent, node.module.replace('.', '/'))
            if not os.path.isdir(name_with_dir):
                name_with_dir = name_with_dir + '.py'
        # e.g. from .. import X
        else:
            name_with_dir = parent

    # Is it a file?
    if name_with_dir.endswith('.py'):
        return self.add_module(
            (node.module, name_with_dir),
            None,
            as_alias_handler(node.names),
            retrieve_import_alias_mapping(node.names),
            from_from=True
        )
    return self.from_directory_import(
        (node.module, name_with_dir),
        not_as_alias_handler(node.names),
        as_alias_handler(node.names),
        retrieve_import_alias_mapping(node.names),
        skip_init=skip_init
    )
def visit_Import(self, node):
    """Handle `import X`.

    Local modules are searched first, then project modules; imports that
    match neither (e.g. third-party libraries) are ignored.
    """
    for name in node.names:
        for module in self.local_modules:
            if name.name == module[0]:
                if os.path.isdir(module[1]):
                    return self.import_package(
                        module,
                        name,
                        name.asname,
                        retrieve_import_alias_mapping(node.names)
                    )
                return self.add_module(
                    module,
                    name.name,
                    name.asname,
                    retrieve_import_alias_mapping(node.names)
                )
        for module in self.project_modules:
            if name.name == module[0]:
                if os.path.isdir(module[1]):
                    return self.import_package(
                        module,
                        name,
                        name.asname,
                        retrieve_import_alias_mapping(node.names)
                    )
                return self.add_module(
                    module,
                    name.name,
                    name.asname,
                    retrieve_import_alias_mapping(node.names)
                )
    return IgnoredNode()
def visit_ImportFrom(self, node):
    """Handle `from X import Y` for relative, local and project modules.

    Relative imports are delegated to handle_relative_import; absolute
    imports are matched against the known local and project modules, and
    anything else (e.g. third-party libraries) is ignored.
    """
    # Is it relative?
    if node.level > 0:
        return self.handle_relative_import(node)
    else:
        for module in self.local_modules:
            if node.module == module[0]:
                if os.path.isdir(module[1]):
                    # Bugfix: from_directory_import requires the
                    # import_alias_mapping argument; it was missing here
                    # (unlike the identical project_modules branch below),
                    # which raised a TypeError for local directory imports.
                    return self.from_directory_import(
                        module,
                        not_as_alias_handler(node.names),
                        as_alias_handler(node.names),
                        retrieve_import_alias_mapping(node.names)
                    )
                return self.add_module(
                    module,
                    None,
                    as_alias_handler(node.names),
                    retrieve_import_alias_mapping(node.names),
                    from_from=True
                )
        for module in self.project_modules:
            name = module[0]
            if node.module == name:
                if os.path.isdir(module[1]):
                    return self.from_directory_import(
                        module,
                        not_as_alias_handler(node.names),
                        as_alias_handler(node.names),
                        retrieve_import_alias_mapping(node.names)
                    )
                return self.add_module(
                    module,
                    None,
                    as_alias_handler(node.names),
                    retrieve_import_alias_mapping(node.names),
                    from_from=True
                )
    return IgnoredNode()
"""This module contains helper function.
Useful when working with the ast module."""
import ast
import os
import subprocess
from functools import lru_cache
# Call-name parts that are never interesting as call targets, e.g. the
# `self` in `self.foo()`.
BLACK_LISTED_CALL_NAMES = ['self']
# Module-level guard so the 2to3 fallback in generate_ast runs at most once.
recursive = False
def _convert_to_3(path):  # pragma: no cover
    """Convert python 2 file to python 3 in place using the 2to3 tool.

    Exits the program with status 1 when 2to3 cannot be run.
    """
    try:
        print('##### Trying to convert file to Python 3. #####')
        subprocess.call(['2to3', '-w', path])
    # Bugfix: a missing 2to3 executable raises FileNotFoundError (an
    # OSError), not subprocess.SubprocessError, so catch both — the error
    # message explicitly covers the not-installed case.
    except (OSError, subprocess.SubprocessError):
        print('Check if 2to3 is installed. '
              'https://docs.python.org/2/library/2to3.html')
        # raise SystemExit instead of the site-provided exit() helper,
        # which is not guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(1)
@lru_cache()
def generate_ast(path):
    """Generate an Abstract Syntax Tree using the ast module.

    Args:
        path(str): The path to the file e.g. example/foo/bar.py

    Returns:
        ast.Module: The parsed tree.

    Raises:
        IOError: When *path* is not a file.
        SyntaxError: When the file cannot be parsed even after an
            attempted python 2 to 3 conversion.
    """
    if os.path.isfile(path):
        # Read raw bytes so ast.parse honours PEP 263 coding declarations
        # and BOMs instead of decoding with the locale's default encoding.
        with open(path, 'rb') as f:
            try:
                return ast.parse(f.read())
            except SyntaxError:  # pragma: no cover
                global recursive
                if not recursive:
                    # Try converting the file from python 2, exactly once.
                    _convert_to_3(path)
                    recursive = True
                    return generate_ast(path)
                else:
                    raise SyntaxError('The ast module can not parse the file'
                                      ' and the python 2 to 3 conversion'
                                      ' also failed.')
    raise IOError('Input needs to be a file. Path: ' + path)
def _get_call_names_helper(node):
    """Recursively yield the name parts of a call target, innermost last.

    e.g. the target of `foo.bar()` yields 'bar', then 'foo'. Blacklisted
    names (such as 'self') are skipped.
    """
    if isinstance(node, ast.Attribute):
        yield node.attr
        yield from _get_call_names_helper(node.value)
    elif isinstance(node, ast.Subscript):
        # e.g. `d['key'].foo()` — descend into the subscripted value.
        yield from _get_call_names_helper(node.value)
    elif isinstance(node, ast.Str):
        yield node.s
    elif isinstance(node, ast.Name):
        if node.id not in BLACK_LISTED_CALL_NAMES:
            yield node.id
def get_call_names(node):
    """Return an iterator over call-name parts, outermost first."""
    parts = list(_get_call_names_helper(node))
    parts.reverse()
    return iter(parts)
def _list_to_dotted_string(list_of_components):
"""Convert a list to a string seperated by a dot."""
return '.'.join(list_of_components)
def get_call_names_as_string(node):
    """Return the dotted call name of *node*, e.g. 'os.path.join'."""
    # Inlines _list_to_dotted_string: simply dot-join the name parts.
    return '.'.join(get_call_names(node))
class Arguments():
    """Represents arguments of a function."""

    def __init__(self, args):
        """Argument container class.

        Args:
            args(ast.arguments): The arguments in a function AST node.
        """
        self.args = args.args
        self.varargs = args.vararg
        self.kwarg = args.kwarg
        self.kwonlyargs = args.kwonlyargs
        self.defaults = args.defaults
        self.kw_defaults = args.kw_defaults

        # Flat list of all argument *names*, in declaration order.
        self.arguments = list()
        if self.args:
            self.arguments.extend([x.arg for x in self.args])
        if self.varargs:
            # Bugfix: `extend` on the name string added one entry per
            # character; the single *args name must be appended whole.
            self.arguments.append(self.varargs.arg)
        if self.kwarg:
            # Bugfix: same as above for the **kwargs name.
            self.arguments.append(self.kwarg.arg)
        if self.kwonlyargs:
            self.arguments.extend([x.arg for x in self.kwonlyargs])

    def __getitem__(self, key):
        """Index into the flat list of argument names."""
        return self.arguments.__getitem__(key)

    def __len__(self):
        """Number of positional arguments (excludes *args, **kwargs and keyword-only)."""
        return self.args.__len__()
"""This module handles module definitions
which basically is a list of module definition."""
import ast
# Contains all project definitions for a program run
# Only used in framework_adaptor.py, but modified here
project_definitions = {}
class ModuleDefinition():
    """Handling of a single module-level definition."""
    module_definitions = None
    name = None
    node = None
    path = None

    def __init__(
        self,
        local_module_definitions,
        name,
        parent_module_name,
        path
    ):
        """Store the definition and derive its (possibly qualified) name.

        When a parent module name is given, the name is prefixed with it;
        parent_module_name may be a plain string or an ast.alias.
        """
        self.module_definitions = local_module_definitions
        self.parent_module_name = parent_module_name
        self.path = path

        if not parent_module_name:
            self.name = name
        elif isinstance(parent_module_name, ast.alias):
            self.name = parent_module_name.name + '.' + name
        else:
            self.name = parent_module_name + '.' + name

    def __str__(self):
        name = self.name if self.name else 'NoName'
        node = str(self.node) if self.node else 'NoNode'
        return "Path:" + self.path + " " + self.__class__.__name__ + ': ' + ';'.join((name, node))
class LocalModuleDefinition(ModuleDefinition):
    """A definition local to the module currently being analysed."""
    pass
class ModuleDefinitions():
    """A collection of module definition.

    Adds to the project definitions list.
    """

    def __init__(
        self,
        import_names=None,
        module_name=None,
        is_init=False,
        filename=None
    ):
        """Optionally set import names and module name.

        Module name should only be set when it is a normal import statement.
        """
        self.import_names = import_names
        # module_name is sometimes ast.alias or a string
        self.module_name = module_name
        self.is_init = is_init
        self.filename = filename
        # Definitions recorded for this module.
        self.definitions = list()
        # Stack of class names while visiting nested class bodies.
        self.classes = list()
        # Mapping of import aliases, e.g. {'alias': 'real_name'}.
        self.import_alias_mapping = dict()

    def append_if_local_or_in_imports(self, definition):
        """Add definition to list.

        Handles local definitions and adds to project_definitions.
        """
        if isinstance(definition, LocalModuleDefinition):
            self.definitions.append(definition)
        elif self.import_names == ["*"]:
            self.definitions.append(definition)
        elif self.import_names and definition.name in self.import_names:
            self.definitions.append(definition)
        elif (self.import_alias_mapping and definition.name in
              self.import_alias_mapping.values()):
            self.definitions.append(definition)

        # NOTE(review): this runs regardless of the branches above, so a
        # definition with a parent module name can be appended twice —
        # confirm whether that is intended.
        if definition.parent_module_name:
            self.definitions.append(definition)

        # First definition for a given AST node wins globally.
        if definition.node not in project_definitions:
            project_definitions[definition.node] = definition

    def get_definition(self, name):
        """Get definitions by name; returns None when not found."""
        for definition in self.definitions:
            if definition.name == name:
                return definition

    def set_definition_node(self, node, name):
        """Set definition by name."""
        definition = self.get_definition(name)
        if definition:
            definition.node = node

    def __str__(self):
        module = 'NoModuleName'
        if self.module_name:
            module = self.module_name

        if self.definitions:
            if isinstance(module, ast.alias):
                return (
                    'Definitions: "' + '", "'
                    .join([str(definition) for definition in self.definitions]) +
                    '" and module_name: ' + module.name +
                    ' and filename: ' + str(self.filename) +
                    ' and is_init: ' + str(self.is_init) + '\n')
            return (
                'Definitions: "' + '", "'
                .join([str(definition) for definition in self.definitions]) +
                '" and module_name: ' + module +
                ' and filename: ' + str(self.filename) +
                ' and is_init: ' + str(self.is_init) + '\n')
        else:
            if isinstance(module, ast.alias):
                return (
                    'import_names is ' + str(self.import_names) +
                    ' No Definitions, module_name: ' + str(module.name) +
                    ' and filename: ' + str(self.filename) +
                    ' and is_init: ' + str(self.is_init) + '\n')
            return (
                'import_names is ' + str(self.import_names) +
                ' No Definitions, module_name: ' + str(module) +
                ' and filename: ' + str(self.filename) +
                ' and is_init: ' + str(self.is_init) + '\n')
"""This module contains all of the CFG nodes types."""
from collections import namedtuple
from ..helper_visitors import LabelVisitor
# A compound CFG fragment: its entry test node, the nodes that flow onward
# to the next statement, and any break statements awaiting a loop exit.
ControlFlowNode = namedtuple(
    'ControlFlowNode',
    ['test', 'last_nodes', 'break_statements']
)
class IgnoredNode():
    """Sentinel returned by visitors whose AST node produces no CFG node."""
class ConnectToExitNode():
    """Marker type shared by raise and return nodes, used in return_handler."""
class Node():
    """A Control Flow Graph node.

    Holds a label, the originating AST node, a source location, and the
    lists of ingoing and outgoing edges.
    """

    def __init__(self, label, ast_node, *, line_number=None, path):
        """Create a Node that can be used in a CFG.

        Args:
            label(str): The label of the node, describing its expression.
            ast_node: The AST node this CFG node was created from, if any.
            line_number(Optional[int]): The line of the expression of the Node.
            path: Path of the file this node originates from.
        """
        self.label = label
        self.ast_node = ast_node
        self.path = path
        # Prefer an explicit line number, then the AST node's, then None.
        if line_number:
            self.line_number = line_number
        else:
            self.line_number = ast_node.lineno if ast_node else None
        self.ingoing = list()
        self.outgoing = list()

    def as_dict(self):
        """Return a plain-dict summary of this node."""
        return {
            'label': self.label.encode('utf-8').decode('utf-8'),
            'line_number': self.line_number,
            'path': self.path,
        }

    def connect(self, successor):
        """Add an edge from this node to *successor*.

        Raise/return style nodes (ConnectToExitNode) only ever connect to
        entry/exit nodes, so other connections are silently skipped.
        """
        if isinstance(self, ConnectToExitNode) and not isinstance(successor, EntryOrExitNode):
            return

        successor.ingoing.append(self)
        self.outgoing.append(successor)

    def connect_predecessors(self, predecessors):
        """Add an edge from every node in *predecessors* to this node."""
        for predecessor in predecessors:
            self.ingoing.append(predecessor)
            predecessor.outgoing.append(self)

    def __str__(self):
        """Print the label of the node."""
        return ' Label: ' + self.label

    def __repr__(self):
        """Print a representation of the node."""
        label = ' '.join(('Label: ', self.label))
        line_number = 'Line number: ' + str(self.line_number)

        if self.ingoing:
            ingoing = 'ingoing:\t ' + str([x.label for x in self.ingoing])
        else:
            ingoing = 'ingoing:\t []'

        if self.outgoing:
            outgoing = 'outgoing:\t ' + str([x.label for x in self.outgoing])
        else:
            outgoing = 'outgoing:\t []'

        return '\n' + '\n'.join((label, line_number, ingoing, outgoing))
class BreakNode(Node):
    """CFG Node that represents a Break statement."""

    def __init__(self, ast_node, *, path):
        """Label the node with its own class name and delegate to Node."""
        super().__init__(type(self).__name__, ast_node, path=path)
class IfNode(Node):
    """CFG Node that represents an If statement."""

    def __init__(self, test_node, ast_node, *, path):
        """Label the node with the source-like text of the test expression."""
        visitor = LabelVisitor()
        visitor.visit(test_node)
        label = 'if ' + visitor.result + ':'
        super().__init__(label, ast_node, path=path)
class TryNode(Node):
    """CFG Node that represents a Try statement."""

    def __init__(self, ast_node, *, path):
        """All try nodes share the constant label 'try:'."""
        super().__init__('try:', ast_node, path=path)
class EntryOrExitNode(Node):
    """CFG Node that represents an Exit or an Entry node."""

    def __init__(self, label):
        # Entry/exit nodes carry no AST node, line number or path.
        super().__init__(label, None, line_number=None, path=None)
class RaiseNode(Node, ConnectToExitNode):
    """CFG Node that represents a Raise statement."""

    def __init__(self, ast_node, *, path):
        """Label the node with the source-like text of the raise statement."""
        visitor = LabelVisitor()
        visitor.visit(ast_node)
        super().__init__(visitor.result, ast_node, path=path)
class AssignmentNode(Node):
    """CFG Node that represents an assignment."""

    def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, line_number=None, path):
        """Create an Assignment node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            ast_node(_ast.Assign, _ast.AugAssign, _ast.Return or None)
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            line_number(Optional[int]): The line of the expression the Node represents.
            path(string): Current filename.
        """
        super().__init__(label, ast_node, line_number=line_number, path=path)
        self.left_hand_side = left_hand_side
        self.right_hand_side_variables = right_hand_side_variables

    def __repr__(self):
        """Extend Node's representation with both sides of the assignment."""
        parts = [
            super().__repr__(),
            'left_hand_side:\t' + str(self.left_hand_side),
            'right_hand_side_variables:\t' + str(self.right_hand_side_variables),
        ]
        return '\n'.join(parts)
class TaintedNode(AssignmentNode):
    """CFG Node that represents a tainted node.

    Only created in framework_adaptor.py and only used in
    `identify_triggers` of vulnerabilities.py.
    """
class RestoreNode(AssignmentNode):
    """Node used for handling restore nodes returning from function calls."""

    def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_number, path):
        """Create a Restore node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            line_number(Optional[int]): The line of the expression the Node represents.
            path(string): Current filename.
        """
        super().__init__(
            label,
            left_hand_side,
            None,  # restore nodes have no AST node of their own
            right_hand_side_variables,
            line_number=line_number,
            path=path
        )
class BBorBInode(AssignmentNode):
    """Node used for handling restore nodes returning from blackbox or builtin function calls."""
    def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, line_number, path, func_name):
        """Create a blackbox-or-builtin (BBorBI) restore node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            ast_node: AST node this restore node was created from, or None.
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            line_number(Optional[int]): The line of the expression the Node represents.
            path(string): Current filename.
            func_name(string): The string we will compare with the blackbox_mapping in vulnerabilities.py
        """
        super().__init__(label, left_hand_side, ast_node, right_hand_side_variables, line_number=line_number, path=path)
        # Argument names of the call; filled in after construction by the CFG builder.
        self.args = list()
        # Starts as this node itself; presumably re-pointed for nested calls — TODO confirm at call sites.
        self.inner_most_call = self
        self.func_name = func_name
class AssignmentCallNode(AssignmentNode):
    """Node used for when a call happens inside of an assignment."""

    def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, line_number, path, call_node):
        """Create an Assignment Call node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            ast_node: The AST node of the assignment.
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            line_number(Optional[int]): The line of the expression the Node represents.
            path(string): Current filename.
            call_node(BBorBInode or RestoreNode): Used in connect_control_flow_node.
        """
        super().__init__(
            label,
            left_hand_side,
            ast_node,
            right_hand_side_variables,
            line_number=line_number,
            path=path
        )
        self.call_node = call_node
        # Flipped to True later when the inner call turns out to be blackbox.
        self.blackbox = False
class ReturnNode(AssignmentNode, ConnectToExitNode):
    """CFG node that represents a return from a call."""

    def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, path):
        """Create a return from a call node.

        The line number is always taken from the AST node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            ast_node: The AST node of the return statement.
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            path(string): Current filename.
        """
        super().__init__(
            label,
            left_hand_side,
            ast_node,
            right_hand_side_variables,
            line_number=ast_node.lineno,
            path=path
        )
class YieldNode(AssignmentNode):
    """CFG Node that represents a yield or yield from.

    The presence of a YieldNode means that a function is a generator.
    """
"""Generates a list of CFGs from a path.
The module finds all python modules and generates an ast for them.
"""
import os
_local_modules = list()
def get_directory_modules(directory):
    """Return a list containing tuples of
    e.g. ('__init__', 'example/import_test_project/__init__.py')
    """
    # Serve the cached result when it was built for this same directory.
    if _local_modules and os.path.dirname(_local_modules[0][1]) == directory:
        return _local_modules

    # A file path was passed, e.g. example/import_test_project/A.py;
    # reduce it to its containing directory.
    if not os.path.isdir(directory):
        directory = os.path.dirname(directory)

    if directory == '':
        return _local_modules

    for entry in os.listdir(directory):
        if _is_python_file(entry):
            name = os.path.splitext(entry)[0]  # A.py -> A
            _local_modules.append((name, os.path.join(directory, entry)))

    return _local_modules
def get_modules(path, prepend_module_root=True):
    """Return a list containing tuples of
    e.g. ('test_project.utils', 'example/test_project/utils.py')
    """
    module_root = os.path.split(path)[1]
    modules = []
    for root, directories, filenames in os.walk(path):
        for filename in filenames:
            if not _is_python_file(filename):
                continue
            # Keep only the directory part below the module root and turn
            # path separators into dots, e.g. foo/bar -> foo.bar
            real_path = os.path.realpath(os.path.join(root, filename))
            directory = os.path.dirname(real_path).split(module_root)[-1].replace(
                os.sep,  # e.g. '/'
                '.'
            )
            directory = directory.replace('.', '', 1)  # drop the leading dot

            name_parts = []
            if prepend_module_root:
                name_parts.append(module_root)
            if directory:
                name_parts.append(directory)

            # A package's __init__.py is addressed by its package name and
            # its directory; any other file contributes its stem.
            if filename == '__init__.py':
                path = root
            else:
                name_parts.append(os.path.splitext(filename)[0])
                path = os.path.join(root, filename)

            modules.append(('.'.join(name_parts), path))
    return modules
def _is_python_file(path):
if os.path.splitext(path)[1] == '.py':
return True
return False
from .call_visitor import CallVisitor
from .label_visitor import LabelVisitor
from .right_hand_side_visitor import RHSVisitor
from .vars_visitor import VarsVisitor
# Public API of the helper_visitors package.
__all__ = [
    'CallVisitor',
    'LabelVisitor',
    'RHSVisitor',
    'VarsVisitor'
]
import ast
import re
from collections import defaultdict, namedtuple
from itertools import count
from ..core.ast_helper import get_call_names_as_string
from .right_hand_side_visitor import RHSVisitor
class CallVisitorResults(
    namedtuple(
        "CallVisitorResults",
        ["args", "kwargs", "unknown_args", "unknown_kwargs"]
    )
):
    """Names found in a matched call, bucketed by how they were passed:
    positionally (args), by keyword (kwargs), or via */** (unknown_*)."""
    __slots__ = ()

    def all_results(self):
        """Yield every collected name: positional buckets first, then
        keyword buckets, then the unknown (*args/**kwargs) buckets."""
        for names in self.args:
            yield from names
        for names in self.kwargs.values():
            yield from names
        yield from self.unknown_args
        yield from self.unknown_kwargs
class CallVisitor(ast.NodeVisitor):
    """Collect, per argument position/keyword, the right-hand-side names
    used in every call whose name matches the configured trigger string."""

    def __init__(self, trigger_str):
        self.unknown_arg_visitor = RHSVisitor()
        self.unknown_kwarg_visitor = RHSVisitor()
        self.argument_visitors = defaultdict(lambda: RHSVisitor())
        self._trigger_str = trigger_str

    def visit_Call(self, call_node):
        func_name = get_call_names_as_string(call_node.func)
        # Match the trigger as the whole name or as its final attribute.
        pattern = r"(^|\.){}$".format(re.escape(self._trigger_str))
        if re.search(pattern, func_name):
            seen_starred = False
            for position, arg in enumerate(call_node.args):
                if isinstance(arg, ast.Starred):
                    seen_starred = True
                # Once *args appears, later positions are unknowable.
                if seen_starred:
                    self.unknown_arg_visitor.visit(arg)
                else:
                    self.argument_visitors[position].visit(arg)
            for keyword in call_node.keywords:
                if keyword.arg is None:  # **kwargs
                    self.unknown_kwarg_visitor.visit(keyword.value)
                else:
                    self.argument_visitors[keyword.arg].visit(keyword.value)
        self.generic_visit(call_node)

    @classmethod
    def get_call_visit_results(cls, trigger_str, node):
        """Visit `node` and bundle the collected names into a
        CallVisitorResults tuple."""
        visitor = cls(trigger_str)
        visitor.visit(node)

        # Pull out the contiguous run of integer-keyed (positional) buckets.
        arg_results = []
        position = 0
        while position in visitor.argument_visitors:
            arg_results.append(set(visitor.argument_visitors.pop(position).result))
            position += 1

        return CallVisitorResults(
            arg_results,
            {k: set(v.result) for k, v in visitor.argument_visitors.items()},
            set(visitor.unknown_arg_visitor.result),
            set(visitor.unknown_kwarg_visitor.result),
        )
import ast
class LabelVisitor(ast.NodeVisitor):
    """Build a source-like string label for an AST node in `self.result`.

    Each visit method appends text to `self.result`; visiting several nodes
    with the same instance concatenates their labels.
    """

    def __init__(self):
        self.result = ''

    def handle_comma_separated(self, comma_separated_list):
        """Visit each element of the list, joining them with ', '."""
        if comma_separated_list:
            for element in range(len(comma_separated_list)-1):
                self.visit(comma_separated_list[element])
                self.result += ', '
            self.visit(comma_separated_list[-1])

    def visit_Tuple(self, node):
        self.result += '('
        self.handle_comma_separated(node.elts)
        self.result += ')'

    def visit_List(self, node):
        self.result += '['
        self.handle_comma_separated(node.elts)
        self.result += ']'

    def visit_Raise(self, node):
        self.result += 'raise'
        if node.exc:
            self.result += ' '
            self.visit(node.exc)
        if node.cause:
            self.result += ' from '
            self.visit(node.cause)

    def visit_withitem(self, node):
        self.result += 'with '
        self.visit(node.context_expr)
        if node.optional_vars:
            self.result += ' as '
            self.visit(node.optional_vars)

    def visit_Return(self, node):
        # Only the returned value contributes to the label.
        if node.value:
            self.visit(node.value)

    def visit_Assign(self, node):
        for target in node.targets:
            self.visit(target)
            self.result = ' '.join((self.result, '='))
            self.insert_space()
        self.visit(node.value)

    def visit_AugAssign(self, node):
        self.visit(node.target)
        self.insert_space()
        self.visit(node.op)
        self.result += '='
        self.insert_space()
        self.visit(node.value)

    def visit_Compare(self, node):
        self.visit(node.left)
        self.insert_space()
        for op, com in zip(node.ops, node.comparators):
            self.visit(op)
            self.insert_space()
            self.visit(com)
            self.insert_space()
        self.result = self.result.rstrip()

    def visit_BinOp(self, node):
        self.visit(node.left)
        self.insert_space()
        self.visit(node.op)
        self.insert_space()
        self.visit(node.right)

    def visit_UnaryOp(self, node):
        self.visit(node.op)
        self.visit(node.operand)

    def visit_BoolOp(self, node):
        for i, value in enumerate(node.values):
            if i == len(node.values)-1:
                self.visit(value)
            else:
                self.visit(value)
                self.visit(node.op)

    def comprehensions(self, node):
        """Shared body for generator/list/set comprehension labels."""
        self.visit(node.elt)
        for expression in node.generators:
            self.result += ' for '
            self.visit(expression.target)
            self.result += ' in '
            self.visit(expression.iter)

    def visit_GeneratorExp(self, node):
        self.result += '('
        self.comprehensions(node)
        self.result += ')'

    def visit_ListComp(self, node):
        self.result += '['
        self.comprehensions(node)
        self.result += ']'

    def visit_SetComp(self, node):
        self.result += '{'
        self.comprehensions(node)
        self.result += '}'

    def visit_DictComp(self, node):
        self.result += '{'
        self.visit(node.key)
        self.result += ' : '
        self.visit(node.value)
        for expression in node.generators:
            self.result += ' for '
            self.visit(expression.target)
            self.result += ' in '
            self.visit(expression.iter)
        self.result += '}'

    def visit_Attribute(self, node):
        self.visit(node.value)
        self.result += '.'
        self.result += node.attr

    def visit_Call(self, node):
        self.visit(node.func)
        self.result += '('
        if node.keywords and node.args:
            self.handle_comma_separated(node.args)
            self.result += ','
        else:
            self.handle_comma_separated(node.args)
        self.handle_comma_separated(node.keywords)
        self.result += ')'

    def visit_keyword(self, node):
        if node.arg:
            self.result += node.arg
            self.result += '='
        self.visit(node.value)

    def insert_space(self):
        self.result += ' '

    def visit_NameConstant(self, node):
        self.result += str(node.value)

    def visit_Subscript(self, node):
        self.visit(node.value)
        self.result += '['
        self.slicev(node.slice)
        self.result += ']'

    def slicev(self, node):
        if isinstance(node, ast.Slice):
            if node.lower:
                self.visit(node.lower)
            if node.upper:
                self.visit(node.upper)
            if node.step:
                self.visit(node.step)
        elif isinstance(node, ast.ExtSlice):
            if node.dims:
                for d in node.dims:
                    self.visit(d)
        else:
            self.visit(node.value)

    # operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift | RShift | BitOr | BitXor | BitAnd | FloorDiv
    def visit_Add(self, node):
        self.result += '+'

    def visit_Sub(self, node):
        self.result += '-'

    def visit_Mult(self, node):
        self.result += '*'

    def visit_MatMult(self, node):
        # Bug fix: this was misspelled `vist_MatMult`, so ast.NodeVisitor
        # never dispatched to it; it also emitted 'x' instead of '@', the
        # actual matrix-multiplication operator token.
        self.result += '@'

    def visit_Div(self, node):
        self.result += '/'

    def visit_Mod(self, node):
        self.result += '%'

    def visit_Pow(self, node):
        self.result += '**'

    def visit_LShift(self, node):
        self.result += '<<'

    def visit_RShift(self, node):
        self.result += '>>'

    def visit_BitOr(self, node):
        self.result += '|'

    def visit_BitXor(self, node):
        self.result += '^'

    def visit_BitAnd(self, node):
        self.result += '&'

    def visit_FloorDiv(self, node):
        self.result += '//'

    # cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
    def visit_Eq(self, node):
        self.result += '=='

    def visit_Gt(self, node):
        self.result += '>'

    def visit_Lt(self, node):
        self.result += '<'

    def visit_NotEq(self, node):
        self.result += '!='

    def visit_GtE(self, node):
        self.result += '>='

    def visit_LtE(self, node):
        self.result += '<='

    def visit_Is(self, node):
        self.result += 'is'

    def visit_IsNot(self, node):
        self.result += 'is not'

    def visit_In(self, node):
        self.result += 'in'

    def visit_NotIn(self, node):
        self.result += 'not in'

    # unaryop = Invert | Not | UAdd | USub
    def visit_Invert(self, node):
        self.result += '~'

    def visit_Not(self, node):
        self.result += 'not '

    def visit_UAdd(self, node):
        self.result += '+'

    def visit_USub(self, node):
        self.result += '-'

    # boolop = And | Or
    def visit_And(self, node):
        self.result += ' and '

    def visit_Or(self, node):
        self.result += ' or '

    def visit_Num(self, node):
        self.result += str(node.n)

    def visit_Name(self, node):
        self.result += node.id

    def visit_Str(self, node):
        self.result += "'" + node.s + "'"

    def visit_joined_str(self, node, surround=True):
        # NOTE(review): `surround` is unused but kept for interface
        # compatibility with existing callers.
        for val in node.values:
            if isinstance(val, ast.Str):
                self.result += val.s
            else:
                self.visit(val)

    def visit_JoinedStr(self, node):
        """
        JoinedStr(expr* values)
        """
        self.result += "f'"
        self.visit_joined_str(node)
        self.result += "'"

    def visit_FormattedValue(self, node):
        """
        FormattedValue(expr value, int? conversion, expr? format_spec)
        """
        self.result += '{'
        self.visit(node.value)
        # conversion is -1 (none) or the ord() of the conversion character.
        self.result += {
            -1: '',     # no formatting
            97: '!a',   # ascii formatting
            114: '!r',  # repr formatting
            115: '!s',  # string formatting
        }[node.conversion]
        if node.format_spec:
            self.result += ':'
            self.visit_joined_str(node.format_spec)
        self.result += '}'

    def visit_Starred(self, node):
        self.result += '*'
        self.visit(node.value)
"""Contains a class that finds all names.
Used to find all variables on a right hand side(RHS) of assignment.
"""
import ast
class RHSVisitor(ast.NodeVisitor):
    """Visitor collecting all names."""

    def __init__(self):
        """Initialize result as an empty list of names."""
        self.result = []

    def visit_Name(self, node):
        self.result.append(node.id)

    def visit_Call(self, node):
        # Only the arguments and keywords are visited; the called name
        # itself is not a right-hand-side variable.
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword)

    @classmethod
    def result_for_node(cls, node):
        """Convenience wrapper: visit `node` and return the collected names."""
        visitor = cls()
        visitor.visit(node)
        return visitor.result
import ast
import itertools
from ..core.ast_helper import get_call_names
class VarsVisitor(ast.NodeVisitor):
    """Collect the names of variables used in an expression.

    Unlike RHSVisitor, the results of nested calls are recorded as
    'ret_<func>' names so taint can flow through call results.
    """

    def __init__(self):
        self.result = list()

    def visit_Name(self, node):
        self.result.append(node.id)

    def visit_BoolOp(self, node):
        for v in node.values:
            self.visit(v)

    def visit_BinOp(self, node):
        self.visit(node.left)
        self.visit(node.right)

    def visit_UnaryOp(self, node):
        self.visit(node.operand)

    def visit_Lambda(self, node):
        self.visit(node.body)

    def visit_IfExp(self, node):
        self.visit(node.test)
        self.visit(node.body)
        self.visit(node.orelse)

    def visit_Dict(self, node):
        for k in node.keys:
            if k is not None:  # None key means a **mapping unpacking
                self.visit(k)
        for v in node.values:
            self.visit(v)

    def visit_Set(self, node):
        for e in node.elts:
            self.visit(e)

    def comprehension(self, node):
        """Visit target, iterable and conditions of one comprehension clause."""
        self.visit(node.target)
        self.visit(node.iter)
        for c in node.ifs:
            self.visit(c)

    def visit_ListComp(self, node):
        self.visit(node.elt)
        for gen in node.generators:
            self.comprehension(gen)

    def visit_SetComp(self, node):
        self.visit(node.elt)
        for gen in node.generators:
            self.comprehension(gen)

    def visit_DictComp(self, node):
        self.visit(node.key)
        self.visit(node.value)
        for gen in node.generators:
            self.comprehension(gen)

    def visit_GeneratorExp(self, node):
        # Bug fix: this was named `visit_GeneratorComp`, which is not an AST
        # class name, so ast.NodeVisitor never dispatched to it and fell back
        # to generic_visit; the correct node class is GeneratorExp.
        self.visit(node.elt)
        for gen in node.generators:
            self.comprehension(gen)

    def visit_Await(self, node):
        self.visit(node.value)

    def visit_Yield(self, node):
        if node.value:
            self.visit(node.value)

    def visit_YieldFrom(self, node):
        self.visit(node.value)

    def visit_Compare(self, node):
        self.visit(node.left)
        for c in node.comparators:
            self.visit(c)

    def visit_Call(self, node):
        # This will not visit Flask in Flask(__name__) but it will visit request in `request.args.get()
        if not isinstance(node.func, ast.Name):
            self.visit(node.func)
        for arg in itertools.chain(node.args, node.keywords):
            if isinstance(arg, ast.Call):
                if isinstance(arg.func, ast.Name):
                    # We can't just visit because we need to add 'ret_'
                    self.result.append('ret_' + arg.func.id)
                elif isinstance(arg.func, ast.Attribute):
                    # e.g. html.replace('{{ param }}', param)
                    # func.attr is replace
                    # func.value.id is html
                    # We want replace
                    self.result.append('ret_' + arg.func.attr)
                else:
                    # Deal with it when we have code that triggers it.
                    # NOTE(review): bare raise outside an except block raises
                    # RuntimeError — kept deliberately as a loud placeholder.
                    raise
            else:
                self.visit(arg)

    def visit_Attribute(self, node):
        if not isinstance(node.value, ast.Name):
            self.visit(node.value)
        else:
            self.result.append(node.value.id)

    def slicev(self, node):
        if isinstance(node, ast.Slice):
            if node.lower:
                self.visit(node.lower)
            if node.upper:
                self.visit(node.upper)
            if node.step:
                self.visit(node.step)
        elif isinstance(node, ast.ExtSlice):
            if node.dims:
                for d in node.dims:
                    self.visit(d)
        else:
            self.visit(node.value)

    def visit_Subscript(self, node):
        if isinstance(node.value, ast.Attribute):
            # foo.bar[1]
            self.result.append(list(get_call_names(node.value))[0])
        self.visit(node.value)
        self.slicev(node.slice)

    def visit_Starred(self, node):
        self.visit(node.value)

    def visit_List(self, node):
        for el in node.elts:
            self.visit(el)

    def visit_Tuple(self, node):
        for el in node.elts:
            self.visit(el)
import argparse
import os
import sys
# Default locations of the bundled vulnerability definition files,
# resolved relative to this module.
_vulnerability_definitions_dir = os.path.join(
    os.path.dirname(__file__),
    'vulnerability_definitions'
)
default_blackbox_mapping_file = os.path.join(
    _vulnerability_definitions_dir,
    'blackbox_mapping.json'
)
default_trigger_word_file = os.path.join(
    _vulnerability_definitions_dir,
    'all_trigger_words.pyt'
)
def _add_required_group(parser):
required_group = parser.add_argument_group('required arguments')
required_group.add_argument(
'targets', metavar='targets', type=str, nargs='+',
help='source file(s) or directory(s) to be tested'
)
def _add_optional_group(parser):
    """Attach all optional command line arguments to `parser`."""
    group = parser.add_argument_group('optional arguments')
    group.add_argument(
        '-a', '--adaptor',
        type=str,
        help='Choose a web framework adaptor: '
             'Flask(Default), Django, Every or Pylons'
    )
    group.add_argument(
        '-pr', '--project-root',
        type=str,
        help='Add project root, only important when the entry '
             'file is not at the root of the project.'
    )
    group.add_argument(
        '-b', '--baseline',
        type=str,
        default=False,
        metavar='BASELINE_JSON_FILE',
        help='Path of a baseline report to compare against '
             '(only JSON-formatted files are accepted)'
    )
    group.add_argument(
        '-j', '--json',
        action='store_true',
        default=False,
        help='Prints JSON instead of report.'
    )
    group.add_argument(
        '-m', '--blackbox-mapping-file',
        type=str,
        default=default_blackbox_mapping_file,
        help='Input blackbox mapping file.'
    )
    group.add_argument(
        '-t', '--trigger-word-file',
        type=str,
        default=default_trigger_word_file,
        help='Input file with a list of sources and sinks'
    )
    group.add_argument(
        '-o', '--output',
        dest='output_file',
        action='store',
        type=argparse.FileType('w'),
        default=sys.stdout,
        help='write report to filename'
    )
    group.add_argument(
        '--ignore-nosec',
        dest='ignore_nosec',
        action='store_true',
        help='do not skip lines with # nosec comments'
    )
    group.add_argument(
        '-r', '--recursive',
        dest='recursive',
        action='store_true',
        help='find and process files in subdirectories'
    )
    group.add_argument(
        '-x', '--exclude',
        dest='excluded_paths',
        action='store',
        default='',
        help='Separate files with commas'
    )
    group.add_argument(
        '--dont-prepend-root',
        dest='prepend_module_root',
        action='store_false',
        default=True,
        help="In project root e.g. /app, imports are not prepended with app.*"
    )
    group.add_argument(
        '--no-local-imports',
        dest='allow_local_imports',
        action='store_false',
        default=True,
        help='If set, absolute imports must be relative to the project root. '
             'If not set, modules in the same directory can be imported just by their names.'
    )
def _add_print_group(parser):
print_group = parser.add_argument_group('print arguments')
print_group.add_argument(
'-trim', '--trim-reassigned-in',
help='Trims the reassigned list to just the vulnerability chain.',
action='store_true',
default=True
)
print_group.add_argument(
'-i', '--interactive',
help='Will ask you about each blackbox function call in vulnerability chains.',
action='store_true',
default=False
)
def _check_required_and_mutually_exclusive_args(parser, args):
    """Validate parsed arguments; calls parser.error (which raises SystemExit)
    when the required `targets` argument is missing."""
    if args.targets is None:
        parser.error('The targets argument is required')
def parse_args(args):
    """Parse the command line arguments.

    Args:
        args(list[str]): Raw command line arguments, without the program name.

    Returns:
        argparse.Namespace: The parsed arguments.

    An empty argument list shows the help text and exits (argparse raises
    SystemExit when handling '-h').
    """
    # Rebind instead of append: don't mutate the caller's list.
    if not args:
        args = ['-h']
    parser = argparse.ArgumentParser(prog='python -m pyt')
    # Drop argparse's default positional/optional groups; we add our own.
    parser._action_groups.pop()
    _add_required_group(parser)
    _add_optional_group(parser)
    _add_print_group(parser)
    args = parser.parse_args(args)
    _check_required_and_mutually_exclusive_args(
        parser,
        args
    )
    return args
from .vulnerabilities import find_vulnerabilities
from .vulnerability_helper import (
get_vulnerabilities_not_in_baseline,
UImode
)
# Public API of the vulnerabilities package.
__all__ = [
    'find_vulnerabilities',
    'get_vulnerabilities_not_in_baseline',
    'UImode'
]
import json
from collections import namedtuple
# Result of parsing a trigger-word definitions file.
Definitions = namedtuple(
    'Definitions',
    (
        'sources',
        'sinks'
    )
)
# `('trigger_word')` was a bare string, not a tuple — it only worked because
# namedtuple also accepts whitespace/comma-separated field-name strings.
# Make the single-field tuple explicit.
Source = namedtuple('Source', ('trigger_word',))
class Sink:
    """A taint sink parsed from the trigger-word definitions file.

    A trigger ending in '(' denotes a function call; only call triggers may
    carry per-argument propagation options.
    """

    def __init__(
        self, trigger, *,
        unlisted_args_propagate=True, unlisted_kwargs_propagate=True,
        arg_list=None, kwarg_list=None,
        sanitisers=None
    ):
        self._trigger = trigger
        self.sanitisers = sanitisers or []
        # When unlisted arguments propagate, arg_list/kwarg_list enumerate
        # the exceptions (non-propagating); otherwise they enumerate the
        # propagating arguments.
        self.arg_list_propagates = not unlisted_args_propagate
        self.kwarg_list_propagates = not unlisted_kwargs_propagate

        is_function_call = trigger[-1] == '('
        if not is_function_call:
            if self.arg_list_propagates or self.kwarg_list_propagates or arg_list or kwarg_list:
                raise ValueError("Propagation options specified, but trigger word isn't a function call")

        self.arg_list = set(arg_list or ())
        self.kwarg_list = set(kwarg_list or ())

    def arg_propagates(self, index):
        """Does taint propagate through the positional argument at `index`?"""
        return self.arg_list_propagates == (index in self.arg_list)

    def kwarg_propagates(self, keyword):
        """Does taint propagate through the given keyword argument?"""
        return self.kwarg_list_propagates == (keyword in self.kwarg_list)

    @property
    def all_arguments_propagate_taint(self):
        """True when no per-argument propagation lists were configured."""
        return not (self.arg_list or self.kwarg_list)

    @property
    def call(self):
        """Function name without the trailing '(' — None for non-calls."""
        return self._trigger[:-1] if self._trigger[-1] == '(' else None

    @property
    def trigger_word(self):
        return self._trigger

    @classmethod
    def from_json(cls, key, data):
        """Build a Sink from one entry of the JSON definitions file."""
        return cls(trigger=key, **data)
def parse(trigger_word_file):
    """Parse the file for source and sink definitions.

    Returns:
        A definitions tuple with sources and sinks.
    """
    with open(trigger_word_file) as fd:
        triggers_dict = json.load(fd)
    source_list = [Source(trigger) for trigger in triggers_dict['sources']]
    sink_list = [
        Sink.from_json(word, options)
        for word, options in triggers_dict['sinks'].items()
    ]
    return Definitions(source_list, sink_list)
"""Module for finding vulnerabilities based on a definitions file."""
import ast
import json
from collections import defaultdict
from ..analysis.definition_chains import build_def_use_chain
from ..analysis.lattice import Lattice
from ..core.node_types import (
AssignmentNode,
BBorBInode,
IfNode,
TaintedNode
)
from ..helper_visitors import (
CallVisitor,
RHSVisitor,
VarsVisitor
)
from .trigger_definitions_parser import parse, Source
from .vulnerability_helper import (
Sanitiser,
TriggerNode,
Triggers,
vuln_factory,
VulnerabilityType,
UImode
)
def identify_triggers(cfg, sources, sinks, lattice, nosec_lines):
    """Identify sources, sinks and sanitisers in a CFG.

    Args:
        cfg(CFG): CFG to find sources, sinks and sanitisers in.
        sources(tuple): list of sources, a source is a (source, sanitiser) tuple.
        sinks(tuple): list of sources, a sink is a (sink, sanitiser) tuple.
        lattice(Lattice): the lattice we're analysing.
        nosec_lines(set): lines with # nosec whitelisting

    Returns:
        Triggers tuple with sink and source nodes and a sanitiser node dict.
    """
    assignment_nodes = filter_cfg_nodes(cfg, AssignmentNode)
    # Nodes tainted directly by the framework adaptor (e.g. URL parameters)
    # become sources of their own.
    framework_sources = [
        TriggerNode(
            Source('Framework function URL parameter'),
            cfg_node=tainted
        ) for tainted in filter_cfg_nodes(cfg, TaintedNode)
    ]
    sources_in_file = find_triggers(assignment_nodes, sources, nosec_lines)
    sources_in_file.extend(framework_sources)

    find_secondary_sources(assignment_nodes, sources_in_file, lattice)

    sinks_in_file = find_triggers(cfg.nodes, sinks, nosec_lines)
    sanitiser_node_dict = build_sanitiser_node_dict(cfg, sinks_in_file)

    return Triggers(sources_in_file, sinks_in_file, sanitiser_node_dict)
def filter_cfg_nodes(cfg, cfg_node_type):
    """Return the nodes of `cfg` that are instances of `cfg_node_type`."""
    matching = []
    for node in cfg.nodes:
        if isinstance(node, cfg_node_type):
            matching.append(node)
    return matching
def find_secondary_sources(assignment_nodes, sources, lattice):
    """
    Sets the secondary_nodes attribute of each source in the sources list.

    Args:
        assignment_nodes([AssignmentNode])
        sources([tuple])
        lattice(Lattice): the lattice we're analysing.
    """
    for src in sources:
        src.secondary_nodes = find_assignments(assignment_nodes, src, lattice)
def find_assignments(
    assignment_nodes,
    source,
    lattice
):
    """Return the assignment nodes that (transitively) reassign `source`.

    Iterates to a fixed point: keeps propagating reassignments until a pass
    adds no new nodes.

    Args:
        assignment_nodes([AssignmentNode])
        source: the source whose cfg_node starts the propagation.
        lattice(Lattice): the lattice we're analysing.
    """
    old = list()
    # propagate reassignments of the source node
    new = [source.cfg_node]
    while new != old:
        # Bug fix: the previous `old = new` (after updating) aliased the
        # same list, so `new != old` was always False and the loop ran
        # exactly once, missing transitive reassignments. Snapshot a copy
        # before updating instead.
        old = new[:]
        update_assignments(new, assignment_nodes, source.cfg_node, lattice)
    # remove source node from result
    del new[0]
    return new
def update_assignments(assignment_list, assignment_nodes, source, lattice):
    """One propagation pass: for every assignment node constrained by a node
    already in `assignment_list`, append it when it reassigns that node's
    variable (via append_node_if_reassigned)."""
    for candidate in assignment_nodes:
        for known in assignment_list:
            if candidate not in assignment_list and lattice.in_constraint(known, candidate):
                append_node_if_reassigned(assignment_list, known, candidate)
def append_node_if_reassigned(assignment_list, secondary, node):
    """Append `node` to the list when it reads or overwrites the variable
    assigned by `secondary`."""
    reads_variable = secondary.left_hand_side in node.right_hand_side_variables
    overwrites_variable = secondary.left_hand_side == node.left_hand_side
    if reads_variable or overwrites_variable:
        assignment_list.append(node)
def find_triggers(
    nodes,
    trigger_words,
    nosec_lines
):
    """Find triggers from the trigger word list in the nodes.

    Args:
        nodes(list[Node]): the nodes to find triggers in.
        trigger_words(list[Union[Sink, Source]]): list of trigger words to look for.
        nosec_lines(set): lines with # nosec whitelisting

    Returns:
        List of found TriggerNodes
    """
    trigger_nodes = list()
    for node in nodes:
        if node.line_number not in nosec_lines:
            # label_contains is already an iterable; the former iter() wrapper
            # was redundant.
            trigger_nodes.extend(label_contains(node, trigger_words))
    return trigger_nodes
def label_contains(node, triggers):
    """Yield a TriggerNode for every trigger whose word occurs in node.label.

    Args:
        node(Node): CFG node to check.
        triggers(list[Union[Sink, Source]]): trigger words to look for.

    Yields:
        TriggerNode per match — possibly several, since multiple trigger
        words can occur in one label.
    """
    for trigger in triggers:
        if trigger.trigger_word in node.label:
            yield TriggerNode(trigger, node)
def build_sanitiser_node_dict(cfg, sinks_in_file):
    """Build a dict of string -> TriggerNode pairs, where the string
    is the sanitiser and the TriggerNode is a TriggerNode of the sanitiser.

    Args:
        cfg(CFG): cfg to traverse.
        sinks_in_file(list[TriggerNode]): list of TriggerNodes containing
                                          the sinks in the file.

    Returns:
        A string -> TriggerNode dict.
    """
    # All sanitiser words declared by the sinks (may contain duplicates).
    sanitiser_words = list()
    for sink in sinks_in_file:
        sanitiser_words.extend(sink.sanitisers)

    # Every (word, cfg_node) pair where the word occurs in the node's label.
    sanitisers_in_file = [
        Sanitiser(word, cfg_node)
        for word in sanitiser_words
        for cfg_node in cfg.nodes
        if word in cfg_node.label
    ]

    # Duplicate words simply overwrite with an identical value.
    return {
        word: list(find_sanitiser_nodes(word, sanitisers_in_file))
        for word in sanitiser_words
    }
def find_sanitiser_nodes(sanitiser, sanitisers_in_file):
    """Find nodes containing a particular sanitiser.

    Args:
        sanitiser(string): sanitiser to look for.
        sanitisers_in_file(list): Sanitiser tuples found in the file.

    Yields:
        The cfg_node of every entry whose trigger_word equals `sanitiser`.
    """
    for entry in sanitisers_in_file:
        if entry.trigger_word == sanitiser:
            yield entry.cfg_node
def get_sink_args(cfg_node):
    """Return the list of variable names flowing into the sink node.

    NOTE(review): branch order matters — a BBorBInode whose ast_node is an
    ast.Call is handled by the first branch (RHSVisitor over the call), not
    by the BBorBInode branch below.
    """
    if isinstance(cfg_node.ast_node, ast.Call):
        rhs_visitor = RHSVisitor()
        rhs_visitor.visit(cfg_node.ast_node)
        return rhs_visitor.result
    elif isinstance(cfg_node.ast_node, ast.Assign):
        return cfg_node.right_hand_side_variables
    elif isinstance(cfg_node, BBorBInode):
        # Blackbox/builtin restore nodes carry their call arguments directly.
        return cfg_node.args
    else:
        # Fall back to collecting every variable used in the expression.
        vv = VarsVisitor()
        vv.visit(cfg_node.ast_node)
        return vv.result
def get_sink_args_which_propagate(sink, ast_node):
    """Return the argument names of the sink's call that can propagate taint.

    Positional and keyword argument names are collected per bucket by
    CallVisitor, then filtered through the sink's propagation configuration;
    the unknown (*args/**kwargs) buckets are included whenever a propagating
    argument could be hiding in them.
    """
    sink_args_with_positions = CallVisitor.get_call_visit_results(sink.trigger.call, ast_node)
    sink_args = []
    for i, vars in enumerate(sink_args_with_positions.args):
        if sink.trigger.arg_propagates(i):
            sink_args.extend(vars)
    if (
        # Either any unspecified arg propagates
        not sink.trigger.arg_list_propagates or
        # or there are some propagating args which weren't passed positionally
        any(1 for position in sink.trigger.arg_list if position >= len(sink_args_with_positions.args))
    ):
        sink_args.extend(sink_args_with_positions.unknown_args)
    for keyword, vars in sink_args_with_positions.kwargs.items():
        if sink.trigger.kwarg_propagates(keyword):
            sink_args.extend(vars)
    if (
        # Either any unspecified kwarg propagates
        not sink.trigger.kwarg_list_propagates or
        # or there are some propagating kwargs which have not been passed by keyword
        sink.trigger.kwarg_list - set(sink_args_with_positions.kwargs.keys())
    ):
        sink_args.extend(sink_args_with_positions.unknown_kwargs)
    return sink_args
def get_vulnerability_chains(
    current_node,
    sink,
    def_use,
    chain=None
):
    """Traverses the def-use graph to find all paths from source to sink that cause a vulnerability.

    Args:
        current_node: the node currently being visited.
        sink: the sink node every yielded path must reach.
        def_use(dict): definition -> uses mapping.
        chain(list(Node)): A path of nodes between source and sink.

    Yields:
        One chain (list of intermediate nodes) per path reaching the sink.
    """
    # `chain=[]` was a mutable default argument; use None as the sentinel.
    # (The original never mutated the default, but the idiom is a landmine.)
    if chain is None:
        chain = []
    for use in def_use[current_node]:
        if use == sink:
            yield chain
        else:
            # Copy so sibling branches don't share one growing list.
            vuln_chain = list(chain)
            vuln_chain.append(use)
            yield from get_vulnerability_chains(
                use,
                sink,
                def_use,
                vuln_chain
            )
def how_vulnerable(
    chain,
    blackbox_mapping,
    sanitiser_nodes,
    potential_sanitiser,
    blackbox_assignments,
    ui_mode,
    vuln_deets
):
    """Iterates through the chain of nodes and checks the blackbox nodes against the blackbox mapping and sanitiser dictionary.

    Note: potential_sanitiser is the only hack here, it is because we do not take p-use's into account yet.
    e.g. we can only say potentially instead of definitely sanitised in the path_traversal_sanitised_2.py test.

    Args:
        chain(list(Node)): A path of nodes between source and sink.
        blackbox_mapping(dict): A map of blackbox functions containing whether or not they propagate taint.
        sanitiser_nodes(set): A set of nodes that are sanitisers for the sink.
        potential_sanitiser(Node): An if or elif node that can potentially cause sanitisation.
        blackbox_assignments(set[AssignmentNode]): set of blackbox assignments, includes the ReturnNode's of BBorBInode's.
        ui_mode(UImode): determines if we interact with the user when we don't already have a blackbox mapping available.
        vuln_deets(dict): vulnerability details.

    Returns:
        A VulnerabilityType depending on how vulnerable the chain is.
    """
    # NOTE(review): blackbox_assignments is currently unused in this body.
    for i, current_node in enumerate(chain):
        if current_node in sanitiser_nodes:
            # A known sanitiser lies on the path: definitely sanitised.
            vuln_deets['sanitiser'] = current_node
            vuln_deets['confident'] = True
            return VulnerabilityType.SANITISED
        if isinstance(current_node, BBorBInode):
            if current_node.func_name in blackbox_mapping['propagates']:
                # Known to propagate taint; keep walking the chain.
                continue
            elif current_node.func_name in blackbox_mapping['does_not_propagate']:
                return VulnerabilityType.FALSE
            elif ui_mode == UImode.INTERACTIVE:
                # Ask the user and record the answer in blackbox_mapping,
                # which the caller persists back to the mapping file.
                # NOTE(review): at i == 0 this reads chain[-1] — confirm a
                # chain never starts with a BBorBInode.
                user_says = input(
                    'Is the return value of {} with tainted argument "{}" vulnerable? (Y/n)'.format(
                        current_node.label,
                        chain[i - 1].left_hand_side
                    )
                ).lower()
                if user_says.startswith('n'):
                    blackbox_mapping['does_not_propagate'].append(current_node.func_name)
                    return VulnerabilityType.FALSE
                blackbox_mapping['propagates'].append(current_node.func_name)
            else:
                # No mapping entry and no user to ask: result is unknown.
                vuln_deets['unknown_assignment'] = current_node
                return VulnerabilityType.UNKNOWN
    if potential_sanitiser:
        # Only an `if` guard was seen, so sanitisation is not certain.
        vuln_deets['sanitiser'] = potential_sanitiser
        vuln_deets['confident'] = False
        return VulnerabilityType.SANITISED
    return VulnerabilityType.TRUE
def get_tainted_node_in_sink_args(
    sink_args,
    nodes_in_constaint
):
    """Return the first node whose LHS appears in sink_args, or None.

    nodes_in_constaint is ordered with the node closest to the sink first.
    """
    if not sink_args:
        return None
    return next(
        (node for node in nodes_in_constaint if node.left_hand_side in sink_args),
        None
    )
def get_vulnerability(
    source,
    sink,
    triggers,
    lattice,
    cfg,
    ui_mode,
    blackbox_mapping
):
    """Get vulnerability between source and sink if it exists.

    Uses triggers to find sanitisers.

    Note: When a secondary node is in_constraint with the sink
    but not the source, the secondary is a save_N_LHS
    node made in process_function in expr_visitor.

    Args:
        source(TriggerNode): TriggerNode of the source.
        sink(TriggerNode): TriggerNode of the sink.
        triggers(Triggers): Triggers of the CFG.
        lattice(Lattice): the lattice we're analysing.
        cfg(CFG): .blackbox_assignments used in is_unknown, .nodes used in build_def_use_chain
        ui_mode(UImode): determines if we interact with the user or trim the nodes in the output, if at all.
        blackbox_mapping(dict): A map of blackbox functions containing whether or not they propagate taint.

    Returns:
        A Vulnerability if it exists, else None
    """
    # Secondary nodes of the source that reach the sink, ordered so the
    # node closest to the sink comes first; the source itself goes last.
    nodes_in_constaint = [secondary for secondary in reversed(source.secondary_nodes)
                          if lattice.in_constraint(secondary,
                                                   sink.cfg_node)]
    nodes_in_constaint.append(source.cfg_node)
    if sink.trigger.all_arguments_propagate_taint:
        sink_args = get_sink_args(sink.cfg_node)
    else:
        # Only some of this sink's arguments can propagate taint.
        sink_args = get_sink_args_which_propagate(sink, sink.cfg_node.ast_node)
    tainted_node_in_sink_arg = get_tainted_node_in_sink_args(
        sink_args,
        nodes_in_constaint,
    )
    if tainted_node_in_sink_arg:
        vuln_deets = {
            'source': source.cfg_node,
            'source_trigger_word': source.trigger_word,
            'sink': sink.cfg_node,
            'sink_trigger_word': sink.trigger_word,
            'reassignment_nodes': source.secondary_nodes
        }
        # Collect the sink's sanitisers present in this CFG. An IfNode can
        # only *potentially* sanitise (p-uses are not tracked yet).
        sanitiser_nodes = set()
        potential_sanitiser = None
        if sink.sanitisers:
            for sanitiser in sink.sanitisers:
                for cfg_node in triggers.sanitiser_dict[sanitiser]:
                    if isinstance(cfg_node, AssignmentNode):
                        sanitiser_nodes.add(cfg_node)
                    elif isinstance(cfg_node, IfNode):
                        potential_sanitiser = cfg_node
        def_use = build_def_use_chain(
            cfg.nodes,
            lattice
        )
        # Report the first chain that is not proven non-propagating.
        for chain in get_vulnerability_chains(
            source.cfg_node,
            sink.cfg_node,
            def_use
        ):
            vulnerability_type = how_vulnerable(
                chain,
                blackbox_mapping,
                sanitiser_nodes,
                potential_sanitiser,
                cfg.blackbox_assignments,
                ui_mode,
                vuln_deets
            )
            if vulnerability_type == VulnerabilityType.FALSE:
                continue
            if ui_mode != UImode.NORMAL:
                # TRIM and INTERACTIVE both replace the full reassignment
                # list with just the chain that was found vulnerable.
                vuln_deets['reassignment_nodes'] = chain
            return vuln_factory(vulnerability_type)(**vuln_deets)
    return None
def find_vulnerabilities_in_cfg(
    cfg,
    definitions,
    lattice,
    ui_mode,
    blackbox_mapping,
    vulnerabilities_list,
    nosec_lines
):
    """Find vulnerabilities in a cfg.

    Args:
        cfg(CFG): The CFG to find vulnerabilities in.
        definitions(trigger_definitions_parser.Definitions): Source and sink definitions.
        lattice(Lattice): the lattice we're analysing.
        ui_mode(UImode): determines if we interact with the user or trim the nodes in the output, if at all.
        blackbox_mapping(dict): A map of blackbox functions containing whether or not they propagate taint.
        vulnerabilities_list(list): That we append to when we find vulnerabilities.
        nosec_lines(dict): filenames mapped to their nosec lines
    """
    triggers = identify_triggers(
        cfg,
        definitions.sources,
        definitions.sinks,
        lattice,
        nosec_lines[cfg.filename]
    )
    # Check every (sink, source) pair; each hit appends one vulnerability.
    for sink in triggers.sinks:
        for source in triggers.sources:
            found = get_vulnerability(
                source,
                sink,
                triggers,
                lattice,
                cfg,
                ui_mode,
                blackbox_mapping
            )
            if found:
                vulnerabilities_list.append(found)
def find_vulnerabilities(
    cfg_list,
    ui_mode,
    blackbox_mapping_file,
    sources_and_sinks_file,
    nosec_lines=None
):
    """Find vulnerabilities in a list of CFGs from a trigger_word_file.

    Args:
        cfg_list(list[CFG]): the list of CFGs to scan.
        ui_mode(UImode): determines if we interact with the user or trim the nodes in the output, if at all.
        blackbox_mapping_file(str): path of the blackbox mapping JSON file.
        sources_and_sinks_file(str): path of the trigger word definitions file.
        nosec_lines(dict): filenames mapped to their nosec lines,
            defaults to a fresh empty defaultdict(set).

    Returns:
        A list of vulnerabilities.
    """
    if nosec_lines is None:
        # Build a fresh mapping per call. A mutable default argument would be
        # shared between calls, and find_vulnerabilities_in_cfg looks up
        # nosec_lines[cfg.filename], which inserts keys into a defaultdict —
        # so the shared default would silently accumulate state.
        nosec_lines = defaultdict(set)
    vulnerabilities = list()
    definitions = parse(sources_and_sinks_file)
    with open(blackbox_mapping_file) as infile:
        blackbox_mapping = json.load(infile)
    for cfg in cfg_list:
        find_vulnerabilities_in_cfg(
            cfg,
            definitions,
            Lattice(cfg.nodes),
            ui_mode,
            blackbox_mapping,
            vulnerabilities,
            nosec_lines
        )
    # Persist blackbox answers gathered during the scan (e.g. interactive mode).
    with open(blackbox_mapping_file, 'w') as outfile:
        json.dump(blackbox_mapping, outfile, indent=4)
    return vulnerabilities
"""This module contains vulnerability types, Enums, nodes and helpers."""
import json
from enum import Enum
from collections import namedtuple
from ..core.node_types import YieldNode
class VulnerabilityType(Enum):
    """How certain the analysis is about a source-to-sink taint chain."""
    FALSE = 0  # The chain is known not to propagate taint.
    SANITISED = 1  # A sanitiser (certainly or potentially) cleans the chain.
    TRUE = 2  # The chain propagates taint all the way to the sink.
    UNKNOWN = 3  # A blackbox call with unknown propagation is in the chain.
class UImode(Enum):
    """How the tool interacts with the user and trims report output."""
    INTERACTIVE = 0  # Ask the user about unknown blackbox calls.
    NORMAL = 1  # Plain reporting with the full reassignment list.
    TRIM = 2  # Trim the reassignment list to the vulnerability chain.
def vuln_factory(vulnerability_type):
    """Map a VulnerabilityType onto the class used to report it."""
    special_classes = {
        VulnerabilityType.UNKNOWN: UnknownVulnerability,
        VulnerabilityType.SANITISED: SanitisedVulnerability,
    }
    return special_classes.get(vulnerability_type, Vulnerability)
def _get_reassignment_str(reassignment_nodes):
reassignments = ''
if reassignment_nodes:
reassignments += '\nReassigned in:\n\t'
reassignments += '\n\t'.join([
'File: ' + node.path + '\n' +
'\t > Line ' + str(node.line_number) + ': ' + node.label
for node in reassignment_nodes
])
return reassignments
class Vulnerability():
    """Base class for a reported source-to-sink vulnerability."""
    def __init__(
        self,
        source,
        source_trigger_word,
        sink,
        sink_trigger_word,
        reassignment_nodes
    ):
        """Set source and sink information.

        Args:
            source(Node): CFG node of the source.
            source_trigger_word(str): word that marked the source.
            sink(Node): CFG node of the sink.
            sink_trigger_word(str): word that marked the sink.
            reassignment_nodes(list[Node]): nodes the taint passed through.
        """
        self.source = source
        self.source_trigger_word = source_trigger_word
        self.sink = sink
        self.sink_trigger_word = sink_trigger_word
        # NOTE(review): get_vulnerability passes source.secondary_nodes here,
        # so the two pruning calls below mutate a list that is shared with
        # the source TriggerNode — confirm this aliasing is intended.
        self.reassignment_nodes = reassignment_nodes
        self._remove_sink_from_secondary_nodes()
        self._remove_non_propagating_yields()
    def _remove_sink_from_secondary_nodes(self):
        """Drop the sink node itself from the reassignment list, if present."""
        try:
            self.reassignment_nodes.remove(self.sink)
        except ValueError:  # pragma: no cover
            # The sink was not among the reassignments; nothing to remove.
            pass
    def _remove_non_propagating_yields(self):
        """Remove yield with no variables e.g. `yield 123` and plain `yield` from vulnerability."""
        # NOTE(review): length 1 is treated as "no real RHS variables";
        # presumably a YieldNode always carries one implicit variable —
        # confirm against YieldNode's constructor.
        for node in list(self.reassignment_nodes):
            if isinstance(node, YieldNode) and len(node.right_hand_side_variables) == 1:
                self.reassignment_nodes.remove(node)
    def __str__(self):
        """Pretty printing of a vulnerability."""
        reassigned_str = _get_reassignment_str(self.reassignment_nodes)
        return (
            'File: {}\n'
            ' > User input at line {}, source "{}":\n'
            '\t {}{}\nFile: {}\n'
            ' > reaches line {}, sink "{}":\n'
            '\t{}'.format(
                self.source.path,
                self.source.line_number, self.source_trigger_word,
                self.source.label, reassigned_str, self.sink.path,
                self.sink.line_number, self.sink_trigger_word,
                self.sink.label
            )
        )
    def as_dict(self):
        """Return a JSON-serialisable representation of the vulnerability."""
        return {
            'source': self.source.as_dict(),
            'source_trigger_word': self.source_trigger_word,
            'sink': self.sink.as_dict(),
            'sink_trigger_word': self.sink_trigger_word,
            'type': self.__class__.__name__,
            'reassignment_nodes': [node.as_dict() for node in self.reassignment_nodes]
        }
class SanitisedVulnerability(Vulnerability):
    """A vulnerability whose chain passes through a sanitiser for the sink."""
    def __init__(
        self,
        confident,
        sanitiser,
        **kwargs
    ):
        """Record the sanitiser and whether sanitisation is certain."""
        super().__init__(**kwargs)
        self.confident = confident
        self.sanitiser = sanitiser
    def __str__(self):
        """Pretty printing of a vulnerability."""
        certainty = '' if self.confident else 'potentially '
        return '{}\nThis vulnerability is {}sanitised by: {}'.format(
            super().__str__(),
            certainty,
            self.sanitiser
        )
    def as_dict(self):
        result = super().as_dict()
        result['sanitiser'] = self.sanitiser.as_dict()
        result['confident'] = self.confident
        return result
class UnknownVulnerability(Vulnerability):
    """A vulnerability whose chain contains a blackbox call with unknown propagation."""
    def __init__(
        self,
        unknown_assignment,
        **kwargs
    ):
        """Record the blackbox assignment that made the result unknown."""
        super().__init__(**kwargs)
        self.unknown_assignment = unknown_assignment
    def __str__(self):
        """Pretty printing of a vulnerability."""
        return '{}\nThis vulnerability is unknown due to: {}'.format(
            super().__str__(),
            self.unknown_assignment
        )
    def as_dict(self):
        result = super().as_dict()
        result['unknown_assignment'] = self.unknown_assignment.as_dict()
        return result
# A sanitiser trigger word paired with the CFG node it was found in.
Sanitiser = namedtuple('Sanitiser', 'trigger_word cfg_node')
# All trigger nodes found in a CFG: sources, sinks and the sanitiser mapping.
Triggers = namedtuple('Triggers', 'sources sinks sanitiser_dict')
class TriggerNode():
    """A CFG node that matched a source or sink trigger word."""
    def __init__(
        self,
        trigger,
        cfg_node,
        secondary_nodes=None
    ):
        """Create a trigger node.

        Args:
            trigger: the source/sink definition that matched.
            cfg_node: the CFG node the trigger word was found in.
            secondary_nodes(list): nodes tainted by this one. Defaults to a
                fresh empty list; the previous mutable default `[]` was
                shared between every instance created without the argument.
        """
        self.trigger = trigger
        self.cfg_node = cfg_node
        self.secondary_nodes = [] if secondary_nodes is None else secondary_nodes
    @property
    def trigger_word(self):
        """The word that triggered this node."""
        return self.trigger.trigger_word
    @property
    def sanitisers(self):
        """Sanitisers of the trigger, or [] for triggers without any."""
        return self.trigger.sanitisers if hasattr(self.trigger, 'sanitisers') else []
    def append(self, cfg_node):
        """Add cfg_node as a secondary node, skipping duplicates and the trigger's own node."""
        if not cfg_node == self.cfg_node:
            if self.secondary_nodes and cfg_node not in self.secondary_nodes:
                self.secondary_nodes.append(cfg_node)
            elif not self.secondary_nodes:
                self.secondary_nodes = [cfg_node]
    def __repr__(self):
        output = 'TriggerNode('
        if self.trigger_word:
            output = '{} trigger_word is {}, '.format(
                output,
                self.trigger_word
            )
        return (
            output +
            'sanitisers are {}, '.format(self.sanitisers) +
            'cfg_node is {})\n'.format(self.cfg_node)
        )
def get_vulnerabilities_not_in_baseline(
    vulnerabilities,
    baseline_file
):
    """Return the vulnerabilities that do not appear in the baseline report.

    Args:
        vulnerabilities(list): vulnerability objects found in this run.
        baseline_file(str): path to a JSON report with a 'vulnerabilities' list.

    Returns:
        The subset of `vulnerabilities` whose dict form is not in the baseline.
    """
    # Context manager so the baseline file handle is closed promptly
    # (the previous bare open() leaked the handle).
    with open(baseline_file) as infile:
        baseline = json.load(infile)
    return [
        vuln for vuln in vulnerabilities
        if vuln.as_dict() not in baseline['vulnerabilities']
    ]
Documentation coming soon.

Sorry, the diff of this file is not supported yet

from .framework_adaptor import (
FrameworkAdaptor,
_get_func_nodes
)
from .framework_helper import (
is_django_view_function,
is_flask_route_function,
is_function,
is_function_without_leading_
)
__all__ = [
'FrameworkAdaptor',
'is_django_view_function',
'is_flask_route_function',
'is_function',
'is_function_without_leading_',
'_get_func_nodes' # Only used in framework_helper_test
]
"""A generic framework adaptor that leaves route criteria to the caller."""
import ast
from ..cfg import make_cfg
from ..core.ast_helper import Arguments
from ..core.module_definitions import project_definitions
from ..core.node_types import (
AssignmentNode,
TaintedNode
)
class FrameworkAdaptor():
    """An engine that uses the template pattern to find all
    entry points in a framework and then taints their arguments.
    """
    def __init__(
        self,
        cfg_list,
        project_modules,
        local_modules,
        is_route_function
    ):
        """Store the inputs and immediately run the adaptor.

        Args:
            cfg_list(list[CFG]): CFGs of the scanned files; route-function
                CFGs are appended to this list in place by run().
            project_modules: modules of the project, used when building CFGs.
            local_modules: modules local to the entry file's directory.
            is_route_function(callable): predicate deciding whether a
                function definition is a framework entry point.
        """
        self.cfg_list = cfg_list
        self.project_modules = project_modules
        self.local_modules = local_modules
        self.is_route_function = is_route_function
        # Side effect on purpose: constructing the adaptor performs the work.
        self.run()
    def get_func_cfg_with_tainted_args(self, definition):
        """Build a function cfg and return it, with all arguments tainted."""
        func_cfg = make_cfg(
            definition.node,
            self.project_modules,
            self.local_modules,
            definition.path,
            definition.module_definitions
        )
        args = Arguments(definition.node.args)
        if args:
            # Detach the entry node from its successors and the first real
            # node from its predecessors, so the tainted argument nodes can
            # be spliced in between them.
            function_entry_node = func_cfg.nodes[0]
            function_entry_node.outgoing = list()
            first_node_after_args = func_cfg.nodes[1]
            first_node_after_args.ingoing = list()
            # We are just going to give all the tainted args the lineno of the def
            definition_lineno = definition.node.lineno
            # Taint all the arguments
            for i, arg in enumerate(args):
                node_type = TaintedNode
                if i == 0 and arg == 'self':
                    # 'self' is not user input, so make it a plain assignment.
                    node_type = AssignmentNode
                arg_node = node_type(
                    label=arg,
                    left_hand_side=arg,
                    ast_node=None,
                    right_hand_side_variables=[],
                    line_number=definition_lineno,
                    path=definition.path
                )
                function_entry_node.connect(arg_node)
                # 1 and not 0 so that Entry Node remains first in the list
                func_cfg.nodes.insert(1, arg_node)
                arg_node.connect(first_node_after_args)
        return func_cfg
    def find_route_functions_taint_args(self):
        """Find all route functions and taint all of their arguments.

        Yields:
            CFG of each route function, with args marked as tainted.
        """
        for definition in _get_func_nodes():
            if self.is_route_function(definition.node):
                yield self.get_func_cfg_with_tainted_args(definition)
    def run(self):
        """Run find_route_functions_taint_args on each CFG."""
        # Collect first, extend after, so cfg_list is not extended while
        # it is still being iterated.
        # NOTE(review): each iteration yields every route function known to
        # project_definitions, so more than one CFG in the list would
        # duplicate work — confirm this is intended.
        function_cfgs = list()
        for _ in self.cfg_list:
            function_cfgs.extend(self.find_route_functions_taint_args())
        self.cfg_list.extend(function_cfgs)
def _get_func_nodes():
    """Get all function definitions known to the project."""
    func_definitions = []
    for definition in project_definitions.values():
        if isinstance(definition.node, ast.FunctionDef):
            func_definitions.append(definition)
    return func_definitions
"""Provides helper functions that help with determining if a function is a route function."""
import ast
from ..core.ast_helper import get_call_names
def is_django_view_function(ast_node):
    """Check whether the function's first argument is named 'request'.

    Django view functions conventionally take the request object as their
    first positional argument.
    """
    # Truthiness instead of len(...) — idiomatic emptiness check.
    if ast_node.args.args:
        return ast_node.args.args[0].arg == 'request'
    return False
def is_flask_route_function(ast_node):
    """Check whether function uses a route decorator."""
    return any(
        isinstance(decorator, ast.Call) and
        _get_last_of_iterable(get_call_names(decorator.func)) == 'route'
        for decorator in ast_node.decorator_list
    )
def is_function(function):
    """Always returns true because arg is always a function.

    Route criterion for the 'Every' adaptor, which treats every
    function definition as an entry point.
    """
    return True
def is_function_without_leading_(ast_node):
    """Check that the function name does not start with an underscore."""
    return not ast_node.name.startswith('_')
def _get_last_of_iterable(iterable):
"""Get last element of iterable."""
item = None
for item in iterable:
pass
return item
+2
-2
Metadata-Version: 1.1
Name: python-taint
Version: 0.34
Version: 0.37
Summary: Find security vulnerabilities in Python web applications using static analysis.

@@ -9,3 +9,3 @@ Home-page: https://github.com/python-security/pyt

License: GPLv2
Download-URL: https://github.com/python-security/pyt/archive/0.34.tar.gz
Download-URL: https://github.com/python-security/pyt/archive/0.37.tar.gz
Description: Check out PyT on `GitHub <https://github.com/python-security/pyt>`_!

@@ -12,0 +12,0 @@ Keywords: security,vulnerability,web,flask,django,static-analysis,program-analysis

@@ -1,25 +0,15 @@

"""This module is the comand line tool of pyt."""
"""The comand line module of PyT."""
import argparse
import os
import sys
from datetime import date
from pprint import pprint
from collections import defaultdict
from .argument_helpers import (
default_blackbox_mapping_file,
default_trigger_word_file,
valid_date,
VulnerabilityFiles,
UImode
from .analysis.constraint_table import initialize_constraint_table
from .analysis.fixed_point import analyse
from .cfg import make_cfg
from .core.ast_helper import generate_ast
from .core.project_handler import (
get_directory_modules,
get_modules
)
from .ast_helper import generate_ast
from .baseline import get_vulnerabilities_not_in_baseline
from .constraint_table import (
initialize_constraint_table,
print_table
)
from .draw import draw_cfgs, draw_lattices
from .expr_visitor import make_cfg
from .fixed_point import analyse
from .formatters import (

@@ -29,4 +19,11 @@ json,

)
from .framework_adaptor import FrameworkAdaptor
from .framework_helper import (
from .usage import parse_args
from .vulnerabilities import (
find_vulnerabilities,
get_vulnerabilities_not_in_baseline,
UImode
)
from .vulnerabilities.vulnerability_helper import SanitisedVulnerability
from .web_frameworks import (
FrameworkAdaptor,
is_django_view_function,

@@ -37,198 +34,37 @@ is_flask_route_function,

)
from .github_search import scan_github, set_github_api_token
from .lattice import print_lattice
from .liveness import LivenessAnalysis
from .project_handler import get_directory_modules, get_modules
from .reaching_definitions import ReachingDefinitionsAnalysis
from .reaching_definitions_taint import ReachingDefinitionsTaintAnalysis
from .repo_runner import get_repos
from .save import (
cfg_to_file,
create_database,
def_use_chain_to_file,
lattice_to_file,
Output,
use_def_chain_to_file,
verbose_cfg_to_file,
vulnerabilities_to_file
)
from .vulnerabilities import find_vulnerabilities
def parse_args(args):
parser = argparse.ArgumentParser(prog='python -m pyt')
parser.set_defaults(which='')
def discover_files(targets, excluded_files, recursive=False):
    """Collect the files to scan from the given targets.

    Directory targets contribute their .py files (only the top level unless
    recursive); plain file targets are taken as-is. Anything whose name is
    in the comma-separated excluded_files is skipped.
    """
    excluded = excluded_files.split(",")
    discovered = list()
    for target in targets:
        if not os.path.isdir(target):
            # A plain file target is included as-is, unless excluded.
            if target not in excluded:
                discovered.append(target)
            continue
        for root, _, filenames in os.walk(target):
            discovered.extend(
                os.path.join(root, filename)
                for filename in filenames
                if filename.endswith('.py') and filename not in excluded
            )
            if not recursive:
                # Without recursion, only the first (top-level) directory
                # yielded by os.walk is considered.
                break
    return discovered
subparsers = parser.add_subparsers()
entry_group = parser.add_mutually_exclusive_group(required=True)
entry_group.add_argument('-f', '--filepath',
help='Path to the file that should be analysed.',
type=str)
entry_group.add_argument('-gr', '--git-repos',
help='Takes a CSV file of git_url, path per entry.',
type=str)
parser.add_argument('-pr', '--project-root',
help='Add project root, this is important when the entry' +
' file is not at the root of the project.', type=str)
parser.add_argument('-d', '--draw-cfg',
help='Draw CFG and output as .pdf file.',
action='store_true')
parser.add_argument('-o', '--output-filename',
help='Output filename.', type=str)
parser.add_argument('-csv', '--csv-path', type=str,
help='Give the path of the csv file'
' repos should be added to.')
print_group = parser.add_mutually_exclusive_group()
print_group.add_argument('-p', '--print',
help='Prints the nodes of the CFG.',
action='store_true')
print_group.add_argument('-vp', '--verbose-print',
help='Verbose printing of -p.', action='store_true')
print_group.add_argument('-trim', '--trim-reassigned-in',
help='Trims the reassigned list to the vulnerability chain.',
action='store_true',
default=False)
print_group.add_argument('-i', '--interactive',
help='Will ask you about each vulnerability chain and blackbox nodes.',
action='store_true',
default=False)
parser.add_argument('-t', '--trigger-word-file',
help='Input trigger word file.',
type=str,
default=default_trigger_word_file)
parser.add_argument('-m', '--blackbox-mapping-file',
help='Input blackbox mapping file.',
type=str,
default=default_blackbox_mapping_file)
parser.add_argument('-py2', '--python-2',
help='[WARNING, EXPERIMENTAL] Turns on Python 2 mode,' +
' needed when target file(s) are written in Python 2.', action='store_true')
parser.add_argument('-l', '--log-level',
help='Choose logging level: CRITICAL, ERROR,' +
' WARNING(Default), INFO, DEBUG, NOTSET.', type=str)
parser.add_argument('-a', '--adaptor',
help='Choose an adaptor: Flask(Default), Django, Every or Pylons',
type=str)
parser.add_argument('-db', '--create-database',
help='Creates a sql file that can be used to' +
' create a database.', action='store_true')
parser.add_argument('-dl', '--draw-lattice',
nargs='+', help='Draws a lattice.')
parser.add_argument('-j', '--json',
help='Prints JSON instead of report.',
action='store_true',
default=False)
analysis_group = parser.add_mutually_exclusive_group()
analysis_group.add_argument('-li', '--liveness',
help='Run liveness analysis. Default is' +
' reaching definitions tainted version.',
action='store_true')
analysis_group.add_argument('-re', '--reaching',
help='Run reaching definitions analysis.' +
' Default is reaching definitions' +
' tainted version.', action='store_true')
analysis_group.add_argument('-rt', '--reaching-taint',
help='This is the default analysis:' +
' reaching definitions tainted version.',
action='store_true')
parser.add_argument('-ppm', '--print-project-modules',
help='Print project modules.', action='store_true')
parser.add_argument('-b', '--baseline',
help='path of a baseline report to compare against '
'(only JSON-formatted files are accepted)',
type=str,
default=False)
save_parser = subparsers.add_parser('save', help='Save menu.')
save_parser.set_defaults(which='save')
save_parser.add_argument('-fp', '--filename-prefix',
help='Filename prefix fx file_lattice.pyt',
type=str)
save_parser.add_argument('-du', '--def-use-chain',
help='Output the def-use chain(s) to file.',
action='store_true')
save_parser.add_argument('-ud', '--use-def-chain',
help='Output the use-def chain(s) to file',
action='store_true')
save_parser.add_argument('-cfg', '--control-flow-graph',
help='Output the CFGs to file.',
action='store_true')
save_parser.add_argument('-vcfg', '--verbose-control-flow-graph',
help='Output the verbose CFGs to file.',
action='store_true')
save_parser.add_argument('-an', '--analysis',
help='Output analysis results to file' +
' in form of a constraint table.',
action='store_true')
save_parser.add_argument('-la', '--lattice', help='Output lattice(s) to file.',
action='store_true')
save_parser.add_argument('-vu', '--vulnerabilities',
help='Output vulnerabilities to file.',
action='store_true')
save_parser.add_argument('-all', '--save-all',
help='Output everything to file.',
action='store_true')
search_parser = subparsers.add_parser(
'github_search',
help='Searches through github and runs PyT'
' on found repositories. This can take some time.')
search_parser.set_defaults(which='search')
search_parser.add_argument(
'-ss', '--search-string', required=True,
help='String for searching for repos on github.', type=str)
search_parser.add_argument('-sd', '--start-date',
help='Start date for repo search. '
'Criteria used is Created Date.',
type=valid_date,
default=date(2010, 1, 1))
return parser.parse_args(args)
def analyse_repo(args, github_repo, analysis_type, ui_mode):
cfg_list = list()
directory = os.path.dirname(github_repo.path)
project_modules = get_modules(directory)
local_modules = get_directory_modules(directory)
tree = generate_ast(github_repo.path)
cfg = make_cfg(
tree,
project_modules,
local_modules,
github_repo.path
def retrieve_nosec_lines(
    path
):
    """Return the 1-indexed line numbers in `path` carrying a nosec comment.

    Both '#nosec' and '# nosec' spellings are recognised.
    """
    # Context manager so the file handle is closed promptly
    # (the previous bare open() leaked the handle).
    with open(path, 'r') as source_file:
        lines = source_file.readlines()
    return set(
        lineno for
        (lineno, line) in enumerate(lines, start=1)
        if '#nosec' in line or '# nosec' in line
    )
cfg_list.append(cfg)
initialize_constraint_table(cfg_list)
analyse(cfg_list, analysis_type=analysis_type)
vulnerabilities = find_vulnerabilities(
cfg_list,
analysis_type,
ui_mode,
VulnerabilityFiles(
args.blackbox_mapping_file,
args.trigger_word_file
)
)
return vulnerabilities
def main(command_line_args=sys.argv[1:]):
def main(command_line_args=sys.argv[1:]): # noqa: C901
args = parse_args(command_line_args)
analysis = ReachingDefinitionsTaintAnalysis
if args.liveness:
analysis = LivenessAnalysis
elif args.reaching:
analysis = ReachingDefinitionsAnalysis
ui_mode = UImode.NORMAL

@@ -240,135 +76,75 @@ if args.interactive:

cfg_list = list()
if args.git_repos:
repos = get_repos(args.git_repos)
for repo in repos:
repo.clone()
vulnerabilities = analyse_repo(args, repo, analysis, ui_mode)
if args.json:
json.report(vulnerabilities, sys.stdout)
else:
text.report(vulnerabilities, sys.stdout)
if not vulnerabilities:
repo.clean_up()
exit()
files = discover_files(
args.targets,
args.excluded_paths,
args.recursive
)
nosec_lines = defaultdict(set)
if args.which == 'search':
set_github_api_token()
scan_github(
args.search_string,
args.start_date,
analysis,
analyse_repo,
args.csv_path,
ui_mode,
args
)
exit()
for path in files:
if not args.ignore_nosec:
nosec_lines[path] = retrieve_nosec_lines(path)
path = os.path.normpath(args.filepath)
if args.project_root:
directory = os.path.normpath(args.project_root)
else:
directory = os.path.dirname(path)
project_modules = get_modules(directory, prepend_module_root=args.prepend_module_root)
local_modules = get_directory_modules(directory)
tree = generate_ast(path)
directory = None
if args.project_root:
directory = os.path.normpath(args.project_root)
else:
directory = os.path.dirname(path)
project_modules = get_modules(directory)
local_modules = get_directory_modules(directory)
cfg = make_cfg(
tree,
project_modules,
local_modules,
path,
allow_local_directory_imports=args.allow_local_imports
)
cfg_list = [cfg]
tree = generate_ast(path, python_2=args.python_2)
framework_route_criteria = is_flask_route_function
if args.adaptor:
if args.adaptor.lower().startswith('e'):
framework_route_criteria = is_function
elif args.adaptor.lower().startswith('p'):
framework_route_criteria = is_function_without_leading_
elif args.adaptor.lower().startswith('d'):
framework_route_criteria = is_django_view_function
cfg_list = list()
cfg = make_cfg(
tree,
project_modules,
local_modules,
path
)
cfg_list.append(cfg)
framework_route_criteria = is_flask_route_function
if args.adaptor:
if args.adaptor.lower().startswith('e'):
framework_route_criteria = is_function
elif args.adaptor.lower().startswith('p'):
framework_route_criteria = is_function_without_leading_
elif args.adaptor.lower().startswith('d'):
framework_route_criteria = is_django_view_function
# Add all the route functions to the cfg_list
FrameworkAdaptor(cfg_list, project_modules, local_modules, framework_route_criteria)
# Add all the route functions to the cfg_list
FrameworkAdaptor(
cfg_list,
project_modules,
local_modules,
framework_route_criteria
)
initialize_constraint_table(cfg_list)
analyse(cfg_list, analysis_type=analysis)
analyse(cfg_list)
vulnerabilities = find_vulnerabilities(
cfg_list,
analysis,
ui_mode,
VulnerabilityFiles(
args.blackbox_mapping_file,
args.trigger_word_file
)
args.blackbox_mapping_file,
args.trigger_word_file,
nosec_lines
)
if args.baseline:
vulnerabilities = get_vulnerabilities_not_in_baseline(vulnerabilities, args.baseline)
vulnerabilities = get_vulnerabilities_not_in_baseline(
vulnerabilities,
args.baseline
)
if args.json:
json.report(vulnerabilities, sys.stdout)
json.report(vulnerabilities, args.output_file)
else:
text.report(vulnerabilities, sys.stdout)
text.report(vulnerabilities, args.output_file)
if args.draw_cfg:
if args.output_filename:
draw_cfgs(cfg_list, args.output_filename)
else:
draw_cfgs(cfg_list)
if args.print:
lattice = print_lattice(cfg_list, analysis)
has_unsanitized_vulnerabilities = any(not isinstance(v, SanitisedVulnerability) for v in vulnerabilities)
if has_unsanitized_vulnerabilities:
sys.exit(1)
print_table(lattice)
for i, e in enumerate(cfg_list):
print('############## CFG number: ', i)
print(e)
if args.verbose_print:
for i, e in enumerate(cfg_list):
print('############## CFG number: ', i)
print(repr(e))
if args.print_project_modules:
print('############## PROJECT MODULES ##############')
pprint(project_modules)
if args.create_database:
create_database(cfg_list, vulnerabilities)
if args.draw_lattice:
draw_lattices(cfg_list)
# Output to file
if args.which == 'save':
if args.filename_prefix:
Output.filename_prefix = args.filename_prefix
if args.save_all:
def_use_chain_to_file(cfg_list)
use_def_chain_to_file(cfg_list)
cfg_to_file(cfg_list)
verbose_cfg_to_file(cfg_list)
lattice_to_file(cfg_list, analysis)
vulnerabilities_to_file(vulnerabilities)
else:
if args.def_use_chain:
def_use_chain_to_file(cfg_list)
if args.use_def_chain:
use_def_chain_to_file(cfg_list)
if args.control_flow_graph:
cfg_to_file(cfg_list)
if args.verbose_control_flow_graph:
verbose_cfg_to_file(cfg_list)
if args.lattice:
lattice_to_file(cfg_list, analysis)
if args.vulnerabilities:
vulnerabilities_to_file(vulnerabilities)
if __name__ == '__main__':
main()
[console_scripts]
pyt = pyt:main
pyt = pyt.__main__:main
Metadata-Version: 1.1
Name: python-taint
Version: 0.34
Version: 0.37
Summary: Find security vulnerabilities in Python web applications using static analysis.

@@ -9,3 +9,3 @@ Home-page: https://github.com/python-security/pyt

License: GPLv2
Download-URL: https://github.com/python-security/pyt/archive/0.34.tar.gz
Download-URL: https://github.com/python-security/pyt/archive/0.37.tar.gz
Description: Check out PyT on `GitHub <https://github.com/python-security/pyt>`_!

@@ -12,0 +12,0 @@ Keywords: security,vulnerability,web,flask,django,static-analysis,program-analysis

@@ -7,37 +7,34 @@ MANIFEST.in

pyt/__main__.py
pyt/alias_helper.py
pyt/analysis_base.py
pyt/argument_helpers.py
pyt/ast_helper.py
pyt/baseline.py
pyt/constraint_table.py
pyt/definition_chains.py
pyt/draw.py
pyt/expr_visitor.py
pyt/expr_visitor_helper.py
pyt/fixed_point.py
pyt/framework_adaptor.py
pyt/framework_helper.py
pyt/github_search.py
pyt/label_visitor.py
pyt/lattice.py
pyt/liveness.py
pyt/module_definitions.py
pyt/node_types.py
pyt/project_handler.py
pyt/reaching_definitions.py
pyt/reaching_definitions_base.py
pyt/reaching_definitions_taint.py
pyt/repo_runner.py
pyt/right_hand_side_visitor.py
pyt/save.py
pyt/stmt_visitor.py
pyt/stmt_visitor_helper.py
pyt/trigger_definitions_parser.py
pyt/vars_visitor.py
pyt/vulnerabilities.py
pyt/vulnerability_helper.py
pyt/usage.py
pyt/analysis/__init__.py
pyt/analysis/constraint_table.py
pyt/analysis/definition_chains.py
pyt/analysis/fixed_point.py
pyt/analysis/lattice.py
pyt/analysis/reaching_definitions_taint.py
pyt/cfg/__init__.py
pyt/cfg/alias_helper.py
pyt/cfg/expr_visitor.py
pyt/cfg/expr_visitor_helper.py
pyt/cfg/make_cfg.py
pyt/cfg/stmt_visitor.py
pyt/cfg/stmt_visitor_helper.py
pyt/core/__init__.py
pyt/core/ast_helper.py
pyt/core/module_definitions.py
pyt/core/node_types.py
pyt/core/project_handler.py
pyt/formatters/__init__.py
pyt/formatters/json.py
pyt/formatters/text.py
pyt/helper_visitors/__init__.py
pyt/helper_visitors/call_visitor.py
pyt/helper_visitors/label_visitor.py
pyt/helper_visitors/right_hand_side_visitor.py
pyt/helper_visitors/vars_visitor.py
pyt/vulnerabilities/__init__.py
pyt/vulnerabilities/trigger_definitions_parser.py
pyt/vulnerabilities/vulnerabilities.py
pyt/vulnerabilities/vulnerability_helper.py
pyt/vulnerability_definitions/README.rst
pyt/vulnerability_definitions/all_trigger_words.pyt

@@ -47,3 +44,7 @@ pyt/vulnerability_definitions/blackbox_mapping.json

pyt/vulnerability_definitions/flask_trigger_words.pyt
pyt/vulnerability_definitions/test_positions.pyt
pyt/vulnerability_definitions/test_triggers.pyt
pyt/web_frameworks/__init__.py
pyt/web_frameworks/framework_adaptor.py
pyt/web_frameworks/framework_helper.py
python_taint.egg-info/PKG-INFO

@@ -53,3 +54,2 @@ python_taint.egg-info/SOURCES.txt

python_taint.egg-info/entry_points.txt
python_taint.egg-info/requires.txt
python_taint.egg-info/top_level.txt

@@ -1,4 +0,1 @@

.. image:: https://img.shields.io/badge/python-v3.6-blue.svg
:target: https://pypi.org/project/python-taint/
.. image:: https://travis-ci.org/python-security/pyt.svg?branch=master

@@ -19,2 +16,5 @@ :target: https://travis-ci.org/python-security/pyt

.. image:: https://img.shields.io/badge/python-v3.6-blue.svg
:target: https://pypi.org/project/python-taint/
Python Taint

@@ -29,18 +29,10 @@ ============

* Detect Command injection
* Detect command injection, SSRF, SQL injection, XSS, directory traversal etc.
* Detect SQL injection
* A lot of customisation possible
* Detect XSS
For a look at recent changes, please see the `changelog`_.
* Detect directory traversal
.. _changelog: https://github.com/python-security/pyt/blob/master/CHANGELOG.md
* Get a control flow graph
* Get a def-use and/or a use-def chain
* Search GitHub and analyse hits with PyT
* A lot of customisation possible
Example usage and output:

@@ -53,12 +45,67 @@

1. git clone https://github.com/python-security/pyt.git
2. cd pyt/
3. python3 setup.py install
4. pyt -h
.. code-block:: python
pip install python-taint
✨🍰✨
PyT can also be installed from source. To do so, clone the repo, and then run:
.. code-block:: python
python3 setup.py install
How It Works
============
Soon you will find a README.rst in every directory in the pyt folder, `start here`_.
.. _start here: https://github.com/python-security/pyt/tree/master/pyt
Usage
=====
.. code-block::
usage: python -m pyt [-h] [-a ADAPTOR] [-pr PROJECT_ROOT]
[-b BASELINE_JSON_FILE] [-j] [-m BLACKBOX_MAPPING_FILE]
[-t TRIGGER_WORD_FILE] [-o OUTPUT_FILE] [--ignore-nosec]
[-r] [-x EXCLUDED_PATHS] [-trim] [-i]
targets [targets ...]
required arguments:
targets source file(s) or directory(s) to be tested
optional arguments:
-a ADAPTOR, --adaptor ADAPTOR
Choose a web framework adaptor: Flask(Default),
Django, Every or Pylons
-pr PROJECT_ROOT, --project-root PROJECT_ROOT
Add project root, only important when the entry file
is not at the root of the project.
-b BASELINE_JSON_FILE, --baseline BASELINE_JSON_FILE
Path of a baseline report to compare against (only
JSON-formatted files are accepted)
-j, --json Prints JSON instead of report.
-m BLACKBOX_MAPPING_FILE, --blackbox-mapping-file BLACKBOX_MAPPING_FILE
Input blackbox mapping file.
-t TRIGGER_WORD_FILE, --trigger-word-file TRIGGER_WORD_FILE
Input file with a list of sources and sinks
-o OUTPUT_FILE, --output OUTPUT_FILE
write report to filename
--ignore-nosec do not skip lines with # nosec comments
-r, --recursive find and process files in subdirectories
-x EXCLUDED_PATHS, --exclude EXCLUDED_PATHS
Separate files with commas
print arguments:
-trim, --trim-reassigned-in
Trims the reassigned list to just the vulnerability
chain.
-i, --interactive Will ask you about each blackbox function call in
vulnerability chains.
Usage from Source
=================
Using it like a user ``python3 -m pyt -f example/vulnerable_code/XSS_call.py save -du``
Using it like a user ``python3 -m pyt examples/vulnerable_code/XSS_call.py``

@@ -71,3 +118,2 @@ Running the tests ``python3 -m tests``

Contributions

@@ -74,0 +120,0 @@ =============

@@ -5,4 +5,5 @@ [metadata]

[egg_info]
tag_svn_revision = 0
tag_date = 0
tag_build =
tag_date = 0

@@ -5,3 +5,3 @@ from setuptools import find_packages

VERSION = '0.34'
VERSION = '0.37'

@@ -34,12 +34,8 @@

keywords=['security', 'vulnerability', 'web', 'flask', 'django', 'static-analysis', 'program-analysis'],
install_requires=[
'graphviz>=0.4.10',
'requests>=2.12',
'GitPython>=2.0.8'
],
install_requires=[],
entry_points={
'console_scripts': [
'pyt = pyt:main'
'pyt = pyt.__main__:main'
]
}
)
"""This module contains alias helper functions for the expr_visitor module."""
def as_alias_handler(alias_list):
    """Returns a list of all the names that will be called."""
    # Prefer the alias (``import x as y`` -> y), fall back to the raw name.
    return [alias.asname or alias.name for alias in alias_list]
def handle_aliases_in_calls(name, import_alias_mapping):
    """Returns either None or the handled alias.
    Used in add_module.

    Args:
        name(str): Dotted name appearing in a call.
        import_alias_mapping(dict): Maps alias -> real name.

    Returns:
        str or None: *name* with its leading alias replaced, or None.
    """
    for key, val in import_alias_mapping.items():
        # e.g. Foo == Foo
        # e.g. Foo.Bar startswith Foo.
        if name == key or \
           name.startswith(key + '.'):
            # Replace the *leading* key with val in name
            # e.g. StarbucksVisitor.Tea -> Eataly.Tea because
            # "from .nested_folder import StarbucksVisitor as Eataly"
            # BUG FIX: str.replace(key, val) rewrote every occurrence of key,
            # e.g. 'Foo.Foo' -> 'Baz.Baz'; only the prefix must be replaced.
            return val + name[len(key):]
    return None
def handle_aliases_in_init_files(name, import_alias_mapping):
    """Returns either None or the handled alias.
    Used in add_module.

    Args:
        name(str): Dotted name appearing in an __init__ file.
        import_alias_mapping(dict): Maps alias -> real name.

    Returns:
        str or None: *name* with its leading value replaced by the key, or None.
    """
    for key, val in import_alias_mapping.items():
        # e.g. Foo == Foo
        # e.g. Foo.Bar startswith Foo.
        if name == val or \
           name.startswith(val + '.'):
            # Replace the *leading* val with key in name
            # e.g. StarbucksVisitor.Tea -> Eataly.Tea because
            # "from .nested_folder import StarbucksVisitor as Eataly"
            # BUG FIX: str.replace(val, key) rewrote every occurrence of val,
            # not just the prefix.
            return key + name[len(val):]
    return None
def handle_fdid_aliases(module_or_package_name, import_alias_mapping):
    """Returns either None or the handled alias.
    Used in add_module.
    fdid means from directory import directory.
    """
    # Reverse lookup: find the alias whose real name matches.
    for alias, real_name in import_alias_mapping.items():
        if real_name == module_or_package_name:
            return alias
    return None
def not_as_alias_handler(names_list):
    """Returns a list of names ignoring any aliases."""
    return [alias.name for alias in names_list]
def retrieve_import_alias_mapping(names_list):
    """Creates a dictionary mapping aliases to their respective name.
    import_alias_names is used in module_definitions.py and visit_Call"""
    # Only aliased imports contribute entries.
    return {alias.asname: alias.name for alias in names_list if alias.asname}
"""This module contains a base class for the analysis component used in PyT."""
from abc import (
ABCMeta,
abstractmethod
)
class AnalysisBase(metaclass=ABCMeta):
    """Base class for fixed point analyses."""

    # Class-level cache shared by all analyses; maps CFG nodes to annotations.
    annotated_cfg_nodes = dict()

    def __init__(self, cfg):
        # Store the CFG and immediately build the lattice over its nodes.
        self.cfg = cfg
        self.build_lattice(cfg)

    @staticmethod
    @abstractmethod
    def get_lattice_elements(cfg_nodes):
        """Yield the CFG nodes that become elements of the lattice."""
        pass

    @abstractmethod
    def equal(self, value, other):
        """Define the equality for two constraint sets
        that are defined by bitvectors."""
        pass

    @abstractmethod
    def build_lattice(self, cfg):
        """Construct the lattice used by this analysis over *cfg*'s nodes."""
        pass

    @abstractmethod
    def dep(self, q_1):
        """Represents the dep mapping from Schwartzbach."""
        pass
import os
from argparse import ArgumentTypeError
from collections import namedtuple
from datetime import datetime
from enum import Enum
# Default JSON file mapping blackbox functions to their taint behaviour.
default_blackbox_mapping_file = os.path.join(
    os.path.dirname(__file__),
    'vulnerability_definitions',
    'blackbox_mapping.json'
)
# Default trigger-word file listing taint sources and sinks.
default_trigger_word_file = os.path.join(
    os.path.dirname(__file__),
    'vulnerability_definitions',
    'all_trigger_words.pyt'
)
def valid_date(s):
    """Parse a YYYY-MM-DD string into a date, for use as an argparse type.

    Raises:
        ArgumentTypeError: When *s* does not match the expected format.
    """
    fmt = "%Y-%m-%d"
    try:
        parsed = datetime.strptime(s, fmt)
    except ValueError:
        raise ArgumentTypeError(
            "Not a valid date: '{0}'. Format: {1}".format(s, fmt)
        )
    return parsed.date()
class UImode(Enum):
    """User-interaction mode for reporting, selected by command-line flags."""
    INTERACTIVE = 0  # Ask about each blackbox call in vulnerability chains (-i).
    NORMAL = 1       # Default behaviour.
    TRIM = 2         # Trim reassignment lists to the vulnerability chain (-trim).
# Bundles the two vulnerability-definition inputs: the blackbox-mapping JSON
# path and the trigger-word (sources/sinks) file path.
VulnerabilityFiles = namedtuple(
    'VulnerabilityFiles',
    (
        'blackbox_mapping',
        'triggers'
    )
)
"""This module contains helper function.
Useful when working with the ast module."""
import ast
import os
import subprocess
# Call names never recorded as interesting calls.
BLACK_LISTED_CALL_NAMES = ['self']
# Module state: set after generate_ast has recursed once post-2to3 conversion.
recursive = False
# Module state: sticky flag, set when a caller requests python-2 handling.
python_2_mode = False
def convert_to_3(path):  # pragma: no cover
    """Convert python 2 file to python 3 in place using the 2to3 tool.

    Args:
        path(str): Path of the file to convert.
    """
    try:
        print('##### Trying to convert file to Python 3. #####')
        subprocess.call(['2to3', '-w', path])
    except Exception:
        # BUG FIX: a bare ``except:`` also swallowed SystemExit and
        # KeyboardInterrupt; Exception keeps those interruptions working.
        print('Check if 2to3 is installed. '
              'https://docs.python.org/2/library/2to3.html')
        exit(1)
def generate_ast(path, python_2=False):
    """Generate an Abstract Syntax Tree using the ast module.

    Args:
        path(str): The path to the file e.g. example/foo/bar.py
        python_2(bool): Determines whether or not to call 2to3.

    Raises:
        SyntaxError: When parsing fails even after an attempted 2to3 conversion.
        IOError: When *path* is not a file.
    """
    # If set, it stays set.
    global python_2_mode
    if python_2:  # pragma: no cover
        python_2_mode = True
    if os.path.isfile(path):
        with open(path, 'r') as f:
            try:
                return ast.parse(f.read())
            except SyntaxError:  # pragma: no cover
                global recursive
                if not recursive:
                    if not python_2_mode:
                        # First failure: convert the file in place, then retry once.
                        convert_to_3(path)
                    recursive = True
                    return generate_ast(path)
                else:
                    raise SyntaxError('The ast module can not parse the file'
                                      ' and the python 2 to 3 conversion'
                                      ' also failed.')
    raise IOError('Input needs to be a file. Path: ' + path)
def list_to_dotted_string(list_of_components):
    """Convert a list to a string separated by a dot."""
    return '.'.join(list_of_components)
def get_call_names_helper(node, result):
    """Collect call-name components of *node* into *result*, innermost last.

    Walks attribute/subscript chains iteratively instead of recursing.
    """
    while True:
        if isinstance(node, ast.Name):
            if node.id not in BLACK_LISTED_CALL_NAMES:
                result.append(node.id)
            return result
        if isinstance(node, ast.Call):
            # Nested calls contribute nothing to the name chain.
            return result
        if isinstance(node, ast.Subscript):
            node = node.value
            continue
        if isinstance(node, ast.Str):
            result.append(node.s)
            return result
        # Attribute access: record the attribute, keep walking its value.
        result.append(node.attr)
        node = node.value
def get_call_names_as_string(node):
    """Get a list of call names as a string, e.g. ``foo.bar.baz``."""
    return list_to_dotted_string(get_call_names(node))
def get_call_names(node):
    """Get a list of call names, outermost first."""
    collected = get_call_names_helper(node, list())
    # Helper collects innermost-last, so reverse for natural order.
    return reversed(collected)
class Arguments():
    """Represents arguments of a function."""

    def __init__(self, args):
        """Create an Argument container class.

        Args:
            args(ast.arguments): The arguments in a function AST node.
        """
        self.args = args.args
        self.varargs = args.vararg
        self.kwarg = args.kwarg
        self.kwonlyargs = args.kwonlyargs
        self.defaults = args.defaults
        self.kw_defaults = args.kw_defaults

        # Flat list of all argument names: positional, *args, **kwargs, kw-only.
        self.arguments = list()
        if self.args:
            self.arguments.extend([x.arg for x in self.args])
        if self.varargs:
            # BUG FIX: extend() over the name string appended each character
            # ('args' -> 'a', 'r', 'g', 's'); the whole name must be appended.
            self.arguments.append(self.varargs.arg)
        if self.kwarg:
            # BUG FIX: same character-by-character extend() defect as varargs.
            self.arguments.append(self.kwarg.arg)
        if self.kwonlyargs:
            self.arguments.extend([x.arg for x in self.kwonlyargs])

    def __getitem__(self, key):
        return self.arguments.__getitem__(key)

    def __len__(self):
        # Deliberately the count of *positional* arguments only.
        return self.args.__len__()
import json
def get_vulnerabilities_not_in_baseline(vulnerabilities, baseline):
baseline = json.load(open(baseline))
output = list()
vulnerabilities =[vuln for vuln in vulnerabilities]
for vuln in vulnerabilities:
if vuln.as_dict() not in baseline['vulnerabilities']:
output.append(vuln)
return(output)
"""Global lookup table for constraints.
Uses cfg node as key and operates on bitvectors in the form of ints."""
constraint_table = dict()
def initialize_constraint_table(cfg_list):
    """Collects all given cfg nodes and initializes the table with value 0."""
    for control_flow_graph in cfg_list:
        for node in control_flow_graph.nodes:
            constraint_table[node] = 0
def constraint_join(cfg_nodes):
    """Looks up all cfg_nodes and joins the bitvectors by using logical or."""
    joined = 0
    for node in cfg_nodes:
        joined |= constraint_table[node]
    return joined
def print_table(lattice):
    """Debug helper: print every node's constraint set, one line per node."""
    print('Constraint table:')
    for node, constraint in constraint_table.items():
        elements = [str(element) for element in lattice.get_elements(constraint)]
        print(str(node) + ': ' + ','.join(elements))
import ast
from .constraint_table import constraint_table
from .lattice import Lattice
from .node_types import AssignmentNode
from .reaching_definitions import ReachingDefinitionsAnalysis
from .vars_visitor import VarsVisitor
def get_vars(node):
    """Yield the variables a CFG node *uses* (reads).

    Visits the node's AST to collect variable names; for If/While only the
    test expression is visited, and class/function definitions use nothing.
    Names are then filtered to those in right_hand_side_variables — nodes
    without that attribute yield every collected name.
    """
    vv = VarsVisitor()
    if isinstance(node.ast_node, (ast.If, ast.While)):
        vv.visit(node.ast_node.test)
    elif isinstance(node.ast_node, (ast.ClassDef, ast.FunctionDef)):
        # NOTE: returning from a generator just ends iteration (yields nothing).
        return set()
    else:
        try:
            vv.visit(node.ast_node)
        except AttributeError:  # If no ast_node
            vv.result = list()
    vv.result = set(vv.result)
    # Filter out lvars:
    for var in vv.result:
        try:
            if var in node.right_hand_side_variables:
                yield var
        except AttributeError:
            # Node has no RHS attribute: treat every collected name as a use.
            yield var
def get_constraint_nodes(node, lattice):
    """Yield every node in *node*'s constraint set except *node* itself."""
    for element in lattice.get_elements(constraint_table[node]):
        if element is not node:
            yield element
def build_use_def_chain(cfg_nodes):
    """Map each CFG node to the (variable, definition) pairs it uses.

    Args:
        cfg_nodes(list[Node]): All nodes of the CFG.

    Returns:
        dict: node -> list of (variable_name, defining_node) tuples.
    """
    use_def = dict()
    lattice = Lattice(cfg_nodes, ReachingDefinitionsAnalysis)
    for node in cfg_nodes:
        definitions = list()
        # PERF: hoisted out of the constraint loop — a node's used variables
        # do not depend on the constraint node being examined.
        used_variables = list(get_vars(node))
        for constraint_node in get_constraint_nodes(node, lattice):
            for var in used_variables:
                if var in constraint_node.left_hand_side:
                    definitions.append((var, constraint_node))
        use_def[node] = definitions
    return use_def
def build_def_use_chain(cfg_nodes):
    """Map each definition node to the list of nodes that use it.

    Args:
        cfg_nodes(list[Node]): All nodes of the CFG.

    Returns:
        dict: AssignmentNode -> list of nodes using its left-hand side.
    """
    def_use = dict()
    lattice = Lattice(cfg_nodes, ReachingDefinitionsAnalysis)
    # For every node
    for node in cfg_nodes:
        # That's a definition
        if isinstance(node, AssignmentNode):
            # Make an empty list for it in def_use dict
            def_use[node] = list()
            # Get its uses
            for variable in node.right_hand_side_variables:
                # Loop through most of the nodes before it
                for earlier_node in get_constraint_nodes(node, lattice):
                    # and add to the 'uses list' of each earlier node, when applicable
                    # 'earlier node' here being a simplification
                    if variable in earlier_node.left_hand_side:
                        # BUG FIX: with loops in the CFG a reaching definition can
                        # occur *later* in cfg_nodes and had no list yet, raising
                        # KeyError; setdefault guards against that.
                        def_use.setdefault(earlier_node, list()).append(node)
    return def_use
"""Draws CFG."""
import argparse
from itertools import permutations
from subprocess import call
from graphviz import Digraph
from .node_types import AssignmentNode
# Characters stripped from node labels before handing them to graphviz
# (':' has special port meaning in the dot language).
IGNORED_LABEL_NAME_CHARACHTERS = ':'

# Graphviz attribute sets for CFG renderings (top-to-bottom, orthogonal edges).
cfg_styles = {
    'graph': {
        'fontsize': '16',
        'fontcolor': 'black',
        'bgcolor': 'transparent',
        'rankdir': 'TB',
        'splines': 'ortho',
        'margin': '0.01',
    },
    'nodes': {
        'fontname': 'Gotham',
        'shape': 'box',
        'fontcolor': 'black',
        'color': 'black',
        'style': 'filled',
        'fillcolor': 'transparent',
    },
    'edges': {
        'style': 'filled',
        'color': 'black',
        'arrowhead': 'normal',
        'fontname': 'Courier',
        'fontsize': '12',
        'fontcolor': 'black',
    }
}

# Graphviz attribute sets for lattice renderings (straight, arrowless edges).
lattice_styles = {
    'graph': {
        'fontsize': '16',
        'fontcolor': 'black',
        'bgcolor': 'transparent',
        'rankdir': 'TB',
        'splines': 'line',
        'margin': '0.01',
        'ranksep': '1',
    },
    'nodes': {
        'fontname': 'Gotham',
        'shape': 'none',
        'fontcolor': 'black',
        'color': 'black',
        'style': 'filled',
        'fillcolor': 'transparent',
    },
    'edges': {
        'style': 'filled',
        'color': 'black',
        'arrowhead': 'none',
        'fontname': 'Courier',
        'fontsize': '12',
        'fontcolor': 'black',
    }
}
def apply_styles(graph, styles):
    """Apply graph/node/edge style dictionaries to *graph* and return it."""
    # ``x.get(key) or {}`` mirrors the original guard: missing or falsy
    # entries contribute nothing.
    graph.graph_attr.update(styles.get('graph') or {})
    graph.node_attr.update(styles.get('nodes') or {})
    graph.edge_attr.update(styles.get('edges') or {})
    return graph
def draw_cfg(cfg, output_filename='output'):
    """Draw CFG and output as pdf."""
    graph = Digraph(format='pdf')
    for node in cfg.nodes:
        label = node.label.replace(IGNORED_LABEL_NAME_CHARACHTERS, '')
        # Entry/Exit nodes are drawn without a box and with a short caption.
        if 'Exit' in label:
            graph.node(label, 'Exit', shape='none')
        elif 'Entry' in label:
            graph.node(label, 'Entry', shape='none')
        else:
            graph.node(label, label)
        for predecessor in node.ingoing:
            graph.edge(
                predecessor.label.replace(IGNORED_LABEL_NAME_CHARACHTERS, ''),
                label
            )
    graph = apply_styles(graph, cfg_styles)
    graph.render(filename=output_filename)
class Node():
    """A lattice node: a subset *s*, its parent node, and child subsets."""

    def __init__(self, s, parent, children=None):
        self.s = s
        self.parent = parent
        self.children = children

    def __str__(self):
        return 'Node: {} Parent: {} Children: {}'.format(
            self.s, self.parent, self.children
        )

    def __hash__(self):
        # Hash via the string form of the subset so equal subsets collide.
        return hash(str(self.s))
def draw_node(l, graph, node):
    """Recursively add *node* and its children to *graph*.

    *l* accumulates (parent_label, child_label) pairs already drawn so no
    edge is emitted twice.
    NOTE(review): indentation was lost in extraction — the recursion is
    reconstructed inside the dedup check (skipping already-drawn subtrees);
    confirm against the original layout.
    """
    node_label = str(node.s)
    graph.node(node_label, node_label)
    for child in node.children:
        child_label = str(child.s)
        graph.node(child_label, child_label)
        if not (node_label, child_label) in l:
            graph.edge(node_label, child_label, )
            l.append((node_label, child_label))
            draw_node(l, graph, child)
def make_lattice(s, length):
    """Build the lattice rooted at subset *s*, with subsets of size *length* below."""
    root = Node(s, None)
    root.children = get_children(root, s, length)
    return root
def get_children(p, s, length):
    """Return the distinct size-*length* subsets of *s* as child Nodes of *p*.

    Recurses until *length* goes negative, building the whole sub-lattice.
    NOTE(review): permutations of the same elements produce duplicate subsets
    which the inner loop filters out; itertools.combinations would avoid the
    duplicates entirely — confirm before changing.
    """
    children = set()
    if length < 0:
        return children
    for subset in permutations(s, length):
        setsubset = set(subset)
        append = True
        # Skip subsets already materialized as a child node.
        for node in children:
            if setsubset == node.s:
                append = False
                break
        if append:
            n = Node(setsubset, p)
            n.children = get_children(n, setsubset, length-1)
            children.add(n)
    return children
def add_anchor(filename):
    """Rewrite <filename>.dot in place, anchoring every edge south -> north.

    Also replaces any line mentioning ``set()`` with an empty-label node
    declaration so graphviz renders the bottom element blank.
    """
    dot_path = filename + '.dot'
    delimiter = '->'
    rewritten = []
    with open(dot_path, 'r') as fd:
        for line in fd:
            if delimiter in line:
                parts = line.split(delimiter)
                # Drop the char before/after the arrow and attach port anchors.
                parts[0] = parts[0][:-1] + ':s '
                parts[1] = parts[1][:-1] + ':n\n'
                parts.insert(1, delimiter)
                rewritten.append(''.join(parts))
            elif 'set()' in line:
                rewritten.append('"set()" [label="{}"]')
            else:
                rewritten.append(line)
    with open(dot_path, 'w') as fd:
        fd.writelines(rewritten)
def run_dot(filename):
    """Invoke graphviz ``dot`` to turn <filename>.dot into a pdf."""
    dot_path = filename + '.dot'
    call(['dot', '-Tpdf', dot_path, '-o', dot_path.replace('.dot', '.pdf')])
def draw_lattice(cfg, output_filename='output'):
    """Draw the lattice over a CFG's assignment labels and output as pdf."""
    graph = Digraph(format='pdf')
    labels = [node.label for node in cfg.nodes if isinstance(node, AssignmentNode)]
    root = make_lattice(labels, len(labels) - 1)
    drawn_edges = list()
    draw_node(drawn_edges, graph, root)
    graph = apply_styles(graph, lattice_styles)
    graph.render(filename=output_filename + '.dot')
    add_anchor(output_filename)
    run_dot(output_filename)
def draw_lattice_from_labels(labels, output_filename):
    """Render the lattice built from *labels* to <output_filename>.pdf."""
    graph = Digraph(format='pdf')
    root = make_lattice(labels, len(labels) - 1)
    drawn_edges = list()
    draw_node(drawn_edges, graph, root)
    graph = apply_styles(graph, lattice_styles)
    graph.render(filename=output_filename + '.dot')
    add_anchor(output_filename)
    run_dot(output_filename)
def draw_lattices(cfg_list, output_prefix='output'):
    """Draw one lattice pdf per CFG, suffixed with its index."""
    for index, cfg in enumerate(cfg_list):
        draw_lattice(cfg, '{}_{}'.format(output_prefix, index))
def draw_cfgs(cfg_list, output_prefix='output'):
    """Draw one CFG pdf per CFG, suffixed with its index."""
    for index, cfg in enumerate(cfg_list):
        draw_cfg(cfg, '{}_{}'.format(output_prefix, index))
# NOTE(review): the parser is constructed at import time; harmless, but it is
# only used when this module is executed directly.
parser = argparse.ArgumentParser()
parser.add_argument('-l', '--labels', nargs='+',
                    help='Set of labels in lattice.')
parser.add_argument('-n', '--name', help='Specify filename.', type=str)

if __name__ == '__main__':
    args = parser.parse_args()
    draw_lattice_from_labels(args.labels, args.name)
from collections import namedtuple
from .node_types import ConnectToExitNode
# (LHS, RHS) pair recording one variable saved before entering a function call.
SavedVariable = namedtuple(
    'SavedVariable',
    (
        'LHS',
        'RHS'
    )
)

# Call names treated as built-in / framework functions rather than
# user-defined functions to visit.
BUILTINS = (
    'get',
    'Flask',
    'run',
    'replace',
    'read',
    'set_cookie',
    'make_response',
    'SQLAlchemy',
    'Column',
    'execute',
    'sessionmaker',
    'Session',
    'filter',
    'call',
    'render_template',
    'redirect',
    'url_for',
    'flash',
    'jsonify'
)
class CFG():
    """Container for a control flow graph's nodes and blackbox assignments."""

    def __init__(self, nodes, blackbox_assignments):
        self.nodes = nodes
        self.blackbox_assignments = blackbox_assignments

    def _describe(self, to_text):
        """Render every node, numbered, using *to_text* (str or repr)."""
        pieces = []
        for index, node in enumerate(self.nodes):
            pieces.append('Node: ' + str(index) + ' ' + to_text(node) + '\n\n')
        return ''.join(pieces)

    def __repr__(self):
        return self._describe(repr)

    def __str__(self):
        return self._describe(str)
def return_connection_handler(nodes, exit_node):
    """Connect all return statements to the Exit node."""
    for node in nodes:
        if isinstance(node, ConnectToExitNode) and exit_node not in node.outgoing:
            node.connect(exit_node)
import ast
from .alias_helper import (
handle_aliases_in_calls
)
from .ast_helper import (
Arguments,
get_call_names_as_string
)
from .expr_visitor_helper import (
BUILTINS,
CFG,
return_connection_handler,
SavedVariable
)
from .label_visitor import LabelVisitor
from .module_definitions import ModuleDefinitions
from .node_types import (
AssignmentCallNode,
AssignmentNode,
BBorBInode,
ConnectToExitNode,
EntryOrExitNode,
IgnoredNode,
Node,
RestoreNode,
ReturnNode
)
from .right_hand_side_visitor import RHSVisitor
from .stmt_visitor import StmtVisitor
from .stmt_visitor_helper import CALL_IDENTIFIER
class ExprVisitor(StmtVisitor):
    def __init__(
        self,
        node,
        project_modules,
        local_modules,
        filename,
        module_definitions=None
    ):
        """Create an empty CFG and populate it by visiting *node*.

        Args:
            node: Module or function AST to build the CFG from.
            project_modules: Modules of the whole project (used for imports).
            local_modules: Modules local to the current directory.
            filename(str): Path of the file being visited.
            module_definitions(ModuleDefinitions): Present when visiting a
                function inside an already-visited module.
        """
        self.project_modules = project_modules
        self.local_modules = local_modules
        self.filenames = [filename]  # Stack; top is the file currently visited.
        self.blackbox_assignments = set()
        self.nodes = list()  # All CFG nodes, in creation order.
        self.function_call_index = 0  # Unique per-call counter ("N" in comments).
        self.undecided = False
        self.function_names = list()
        self.function_return_stack = list()
        self.module_definitions_stack = list()
        self.prev_nodes_to_avoid = list()
        self.last_control_flow_nodes = list()

        # Are we already in a module?
        if module_definitions:
            self.init_function_cfg(node, module_definitions)
        else:
            self.init_cfg(node)
def init_cfg(self, node):
self.module_definitions_stack.append(ModuleDefinitions(filename=self.filenames[-1]))
entry_node = self.append_node(EntryOrExitNode('Entry module'))
module_statements = self.visit(node)
if not module_statements:
raise Exception('Empty module. It seems that your file is empty,' +
'there is nothing to analyse.')
exit_node = self.append_node(EntryOrExitNode('Exit module'))
if isinstance(module_statements, IgnoredNode):
entry_node.connect(exit_node)
return
first_node = module_statements.first_statement
if CALL_IDENTIFIER not in first_node.label:
entry_node.connect(first_node)
last_nodes = module_statements.last_statements
exit_node.connect_predecessors(last_nodes)
    def init_function_cfg(self, node, module_definitions):
        """Build the CFG for a single function: Entry -> body -> Exit."""
        self.module_definitions_stack.append(module_definitions)

        self.function_names.append(node.name)
        self.function_return_stack.append(node.name)

        entry_node = self.append_node(EntryOrExitNode('Entry function'))

        module_statements = self.stmt_star_handler(node.body)

        exit_node = self.append_node(EntryOrExitNode('Exit function'))

        if isinstance(module_statements, IgnoredNode):
            # Empty/ignored body: Entry connects straight to Exit.
            entry_node.connect(exit_node)
            return

        first_node = module_statements.first_statement
        # Call nodes manage their own entry connection.
        if CALL_IDENTIFIER not in first_node.label:
            entry_node.connect(first_node)

        last_nodes = module_statements.last_statements
        exit_node.connect_predecessors(last_nodes)
    def visit_Yield(self, node):
        """Model ``yield`` as a ReturnNode: ``yield_<function> = <value>``."""
        label = LabelVisitor()
        label.visit(node)

        try:
            rhs_visitor = RHSVisitor()
            rhs_visitor.visit(node.value)
        except AttributeError:
            # Bare ``yield`` has node.value == None.
            # NOTE(review): result is a string here but a list elsewhere —
            # looks like a deliberate sentinel; confirm before changing.
            rhs_visitor.result = 'EmptyYield'

        this_function_name = self.function_return_stack[-1]
        LHS = 'yield_' + this_function_name
        return self.append_node(ReturnNode(
            LHS + ' = ' + label.result,
            LHS,
            node,
            rhs_visitor.result,
            path=self.filenames[-1])
        )
    # The following expression kinds all route through the same generic
    # handler; only strings are ignored outright.
    def visit_Attribute(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Name(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_NameConstant(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Str(self, node):
        # Bare string expressions (e.g. docstrings) add nothing to the CFG.
        return IgnoredNode()

    def visit_Subscript(self, node):
        return self.visit_miscelleaneous_node(
            node
        )

    def visit_Tuple(self, node):
        return self.visit_miscelleaneous_node(
            node
        )
    def connect_if_allowed(
        self,
        previous_node,
        node_to_connect_to
    ):
        """Connect *previous_node* to *node_to_connect_to* unless control flow forbids it."""
        # e.g.
        # while x != 10:
        #     if x > 0:
        #         print(x)
        #         break
        #     else:
        #         print('hest')
        # print('next')   # self.nodes[-1] is print('hest')
        #
        # So we connect to `while x != 10` instead
        if self.last_control_flow_nodes[-1]:
            self.last_control_flow_nodes[-1].connect(node_to_connect_to)
            # Consume the pending control-flow node so it connects only once.
            self.last_control_flow_nodes[-1] = None
            return

        # Except in this case:
        #
        # if not image_name:
        #     return 404
        # print('foo')  # We do not want to connect this line with `return 404`
        if previous_node is not self.prev_nodes_to_avoid[-1] and not isinstance(previous_node, ReturnNode):
            previous_node.connect(node_to_connect_to)
def save_local_scope(
self,
line_number,
saved_function_call_index
):
"""Save the local scope before entering a function call by saving all the LHS's of assignments so far.
Args:
line_number(int): Of the def of the function call about to be entered into.
saved_function_call_index(int): Unique number for each call.
Returns:
saved_variables(list[SavedVariable])
first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
"""
saved_variables = list()
saved_variables_so_far = set()
first_node = None
# Make e.g. save_N_LHS = assignment.LHS for each AssignmentNode
for assignment in [node for node in self.nodes
if (type(node) == AssignmentNode or
type(node) == AssignmentCallNode or
type(Node) == BBorBInode)]: # type() is used on purpose here
if assignment.left_hand_side in saved_variables_so_far:
continue
saved_variables_so_far.add(assignment.left_hand_side)
save_name = 'save_{}_{}'.format(saved_function_call_index, assignment.left_hand_side)
previous_node = self.nodes[-1]
saved_scope_node = RestoreNode(
save_name + ' = ' + assignment.left_hand_side,
save_name,
[assignment.left_hand_side],
line_number=line_number,
path=self.filenames[-1]
)
if not first_node:
first_node = saved_scope_node
self.nodes.append(saved_scope_node)
# Save LHS
saved_variables.append(SavedVariable(LHS=save_name,
RHS=assignment.left_hand_side))
self.connect_if_allowed(previous_node, saved_scope_node)
return (saved_variables, first_node)
    def save_def_args_in_temp(
        self,
        call_args,
        def_args,
        line_number,
        saved_function_call_index,
        first_node
    ):
        """Save the arguments of the definition being called. Visit the arguments if they're calls.

        Args:
            call_args(list[ast.Name]): Of the call being made.
            def_args(ast_helper.Arguments): Of the definition being called.
            line_number(int): Of the call being made.
            saved_function_call_index(int): Unique number for each call.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.

        Returns:
            args_mapping(dict): A mapping of call argument to definition argument.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
        """
        args_mapping = dict()
        last_return_value_of_nested_call = None

        # Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument
        for i, call_arg in enumerate(call_args):
            # If this results in an IndexError it is invalid Python
            def_arg_temp_name = 'temp_' + str(saved_function_call_index) + '_' + def_args[i]

            return_value_of_nested_call = None
            if isinstance(call_arg, ast.Call):
                # Visit the nested call first so its return value can be saved.
                return_value_of_nested_call = self.visit(call_arg)
                restore_node = RestoreNode(
                    def_arg_temp_name + ' = ' + return_value_of_nested_call.left_hand_side,
                    def_arg_temp_name,
                    [return_value_of_nested_call.left_hand_side],
                    line_number=line_number,
                    path=self.filenames[-1]
                )
                if return_value_of_nested_call in self.blackbox_assignments:
                    # Propagate blackbox status to the restore node.
                    self.blackbox_assignments.add(restore_node)
            else:
                call_arg_label_visitor = LabelVisitor()
                call_arg_label_visitor.visit(call_arg)
                call_arg_rhs_visitor = RHSVisitor()
                call_arg_rhs_visitor.visit(call_arg)
                restore_node = RestoreNode(
                    def_arg_temp_name + ' = ' + call_arg_label_visitor.result,
                    def_arg_temp_name,
                    call_arg_rhs_visitor.result,
                    line_number=line_number,
                    path=self.filenames[-1]
                )

            # If there are no saved variables, then this is the first node
            if not first_node:
                first_node = restore_node

            if isinstance(call_arg, ast.Call):
                if last_return_value_of_nested_call:
                    # connect inner to other_inner in e.g. `outer(inner(image_name), other_inner(image_name))`
                    if isinstance(return_value_of_nested_call, BBorBInode):
                        last_return_value_of_nested_call.connect(return_value_of_nested_call)
                    else:
                        last_return_value_of_nested_call.connect(return_value_of_nested_call.first_node)
                else:
                    # I should only set this once per loop, inner in e.g. `outer(inner(image_name), other_inner(image_name))`
                    # (inner_most_call is used when predecessor is a ControlFlowNode in connect_control_flow_node)
                    if isinstance(return_value_of_nested_call, BBorBInode):
                        first_node.inner_most_call = return_value_of_nested_call
                    else:
                        first_node.inner_most_call = return_value_of_nested_call.first_node
                # We purposefully should not set this as the first_node of return_value_of_nested_call, last makes sense
                last_return_value_of_nested_call = return_value_of_nested_call

            self.connect_if_allowed(self.nodes[-1], restore_node)
            self.nodes.append(restore_node)

            if isinstance(call_arg, ast.Call):
                args_mapping[return_value_of_nested_call.left_hand_side] = def_args[i]
            else:
                args_mapping[def_args[i]] = call_arg_label_visitor.result

        return (args_mapping, first_node)
    def create_local_scope_from_def_args(
        self,
        call_args,
        def_args,
        line_number,
        saved_function_call_index
    ):
        """Create the local scope before entering the body of a function call.

        Args:
            call_args(list[ast.Name]): Of the call being made.
            def_args(ast_helper.Arguments): Of the definition being called.
            line_number(int): Of the def of the function call about to be entered into.
            saved_function_call_index(int): Unique number for each call.

        Note: We do not need a connect_if_allowed because of the
              preceding call to save_def_args_in_temp.
        """
        # Create e.g. def_arg1 = temp_N_def_arg1 for each argument
        for i in range(len(call_args)):
            def_arg_local_name = def_args[i]
            def_arg_temp_name = 'temp_' + str(saved_function_call_index) + '_' + def_args[i]
            local_scope_node = RestoreNode(
                def_arg_local_name + ' = ' + def_arg_temp_name,
                def_arg_local_name,
                [def_arg_temp_name],
                line_number=line_number,
                path=self.filenames[-1]
            )
            # Chain the local scope nodes together
            self.nodes[-1].connect(local_scope_node)
            self.nodes.append(local_scope_node)
    def visit_and_get_function_nodes(
        self,
        definition,
        first_node
    ):
        """Visits the nodes of a user defined function.

        Args:
            definition(LocalModuleDefinition): Definition of the function being added.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.

        Returns:
            the_new_nodes(list[Node]): The nodes added while visiting the function.
            first_node(EntryOrExitNode or None or RestoreNode): Used to connect previous statements to this function.
        """
        # Remember where we were so the function's nodes can be sliced out below.
        len_before_visiting_func = len(self.nodes)
        previous_node = self.nodes[-1]
        entry_node = self.append_node(EntryOrExitNode('Function Entry ' +
                                                      definition.name))
        if not first_node:
            first_node = entry_node
        self.connect_if_allowed(previous_node, entry_node)

        function_body_connect_statements = self.stmt_star_handler(definition.node.body)
        entry_node.connect(function_body_connect_statements.first_statement)

        exit_node = self.append_node(EntryOrExitNode('Exit ' + definition.name))
        exit_node.connect_predecessors(function_body_connect_statements.last_statements)

        the_new_nodes = self.nodes[len_before_visiting_func:]
        # Route any Return/Raise inside the body to the function's Exit node.
        return_connection_handler(the_new_nodes, exit_node)

        return (the_new_nodes, first_node)
    def restore_saved_local_scope(
        self,
        saved_variables,
        args_mapping,
        line_number
    ):
        """Restore the previously saved variables to their original values.

        Args:
            saved_variables(list[SavedVariable])
            args_mapping(dict): A mapping of call argument to definition argument.
            line_number(int): Of the def of the function call about to be entered into.

        Note: We do not need connect_if_allowed because of the
              preceding call to save_local_scope.
        """
        restore_nodes = list()
        for var in saved_variables:
            # Is var.RHS a call argument?
            if var.RHS in args_mapping:
                # If so, use the corresponding definition argument for the RHS of the label.
                restore_nodes.append(RestoreNode(
                    var.RHS + ' = ' + args_mapping[var.RHS],
                    var.RHS,
                    [var.LHS],
                    line_number=line_number,
                    path=self.filenames[-1]
                ))
            else:
                # Create a node for e.g. foo = save_1_foo
                restore_nodes.append(RestoreNode(
                    var.RHS + ' = ' + var.LHS,
                    var.RHS,
                    [var.LHS],
                    line_number=line_number,
                    path=self.filenames[-1]
                ))

        # Chain the restore nodes
        for node, successor in zip(restore_nodes, restore_nodes[1:]):
            node.connect(successor)

        if restore_nodes:
            # Connect the last node to the first restore node
            self.nodes[-1].connect(restore_nodes[0])
            self.nodes.extend(restore_nodes)

        return restore_nodes
    def return_handler(
        self,
        call_node,
        function_nodes,
        saved_function_call_index,
        first_node
    ):
        """Handle the return from a function during a function call.

        Args:
            call_node(ast.Call) : The node that calls the definition.
            function_nodes(list[Node]): List of nodes of the function being called.
            saved_function_call_index(int): Unique number for each call.
            first_node(EntryOrExitNode or RestoreNode): Used to connect previous statements to this function.
        """
        for node in function_nodes:
            # Only `Return`s and `Raise`s can be of type ConnectToExitNode
            if isinstance(node, ConnectToExitNode):
                # Create e.g. ~call_1 = ret_func_foo RestoreNode
                LHS = CALL_IDENTIFIER + 'call_' + str(saved_function_call_index)
                RHS = 'ret_' + get_call_names_as_string(call_node.func)
                return_node = RestoreNode(
                    LHS + ' = ' + RHS,
                    LHS,
                    [RHS],
                    line_number=call_node.lineno,
                    path=self.filenames[-1]
                )
                return_node.first_node = first_node

                self.nodes[-1].connect(return_node)
                self.nodes.append(return_node)
                # One restore node is enough; stop after the first match.
                return
def process_function(self, call_node, definition):
    """Processes a user defined function when it is called.

    Increments self.function_call_index each time it is called, we can refer to it as N in the comments.
    Make e.g. save_N_LHS = assignment.LHS for each AssignmentNode. (save_local_scope)
    Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument.
    Visit the arguments if they're calls. (save_def_args_in_temp)
    Create e.g. def_arg1 = temp_N_def_arg1 for each argument. (create_local_scope_from_def_args)
    Visit and get function nodes. (visit_and_get_function_nodes)
    Loop through each save_N_LHS node and create an e.g.
    foo = save_1_foo or, if foo was a call arg, foo = arg_mapping[foo]. (restore_saved_local_scope)
    Create e.g. ~call_1 = ret_func_foo RestoreNode. (return_handler)

    Notes:
        Page 31 in the original thesis, but changed a little.
        We don't have to return the ~call_1 = ret_func_foo RestoreNode made in return_handler,
        because it's the last node anyway, that we return in this function.
        e.g. ret_func_foo gets assigned to visit_Return.

    Args:
        call_node(ast.Call) : The node that calls the definition.
        definition(LocalModuleDefinition): Definition of the function being called.

    Returns:
        Last node in self.nodes, probably the return of the function appended to self.nodes in return_handler.
    """
    self.function_call_index += 1
    saved_function_call_index = self.function_call_index
    def_node = definition.node

    # Step 1: save the caller's local scope as save_N_<var> nodes.
    saved_variables, first_node = self.save_local_scope(
        def_node.lineno,
        saved_function_call_index
    )

    # Step 2: evaluate the call arguments into temp_N_<def_arg> nodes.
    args_mapping, first_node = self.save_def_args_in_temp(
        call_node.args,
        Arguments(def_node.args),
        call_node.lineno,
        saved_function_call_index,
        first_node
    )

    # Visit the function body relative to the definition's own file.
    self.filenames.append(definition.path)

    # Step 3: bind the definition's parameters from the temps.
    self.create_local_scope_from_def_args(
        call_node.args,
        Arguments(def_node.args),
        def_node.lineno,
        saved_function_call_index
    )

    # Step 4: build CFG nodes for the called function's body.
    function_nodes, first_node = self.visit_and_get_function_nodes(
        definition,
        first_node
    )
    self.filenames.pop()  # Should really probably move after restore_saved_local_scope!!!

    # Step 5: restore the caller's locals from the save_N_<var> nodes.
    self.restore_saved_local_scope(
        saved_variables,
        args_mapping,
        def_node.lineno
    )

    # Step 6: create the ~call_N = ret_<func> restore node.
    self.return_handler(
        call_node,
        function_nodes,
        saved_function_call_index,
        first_node
    )
    self.function_return_stack.pop()

    return self.nodes[-1]
def visit_Call(self, node):
    """Dispatch a Call node.

    User-defined functions are processed inline (process_function);
    anything else becomes a blackbox or builtin call node.
    """
    _id = get_call_names_as_string(node.func)
    local_definitions = self.module_definitions_stack[-1]

    # Resolve import aliases (e.g. `import foo as bar`) before lookup.
    alias = handle_aliases_in_calls(_id, local_definitions.import_alias_mapping)
    if alias:
        definition = local_definitions.get_definition(alias)
    else:
        definition = local_definitions.get_definition(_id)

    # e.g. "request.args.get" -> "get"
    last_attribute = _id.rpartition('.')[-1]

    if definition:
        if isinstance(definition.node, ast.ClassDef):
            # NOTE(review): no `return` here, so control also reaches the
            # final add_blackbox_or_builtin_call below, calling it twice
            # for a ClassDef — TODO confirm whether that is intended.
            self.add_blackbox_or_builtin_call(node, blackbox=False)
        elif isinstance(definition.node, ast.FunctionDef):
            self.undecided = False
            self.function_return_stack.append(_id)
            return self.process_function(node, definition)
        else:
            raise Exception('Definition was neither FunctionDef or ' +
                            'ClassDef, cannot add the function ')
    elif last_attribute not in BUILTINS:
        # Mark the call as a blackbox because we don't have the definition
        return self.add_blackbox_or_builtin_call(node, blackbox=True)
    return self.add_blackbox_or_builtin_call(node, blackbox=False)
def make_cfg(
    node,
    project_modules,
    local_modules,
    filename,
    module_definitions=None
):
    """Visit *node* with an ExprVisitor and wrap the result in a CFG.

    Args:
        node: AST root to build a control flow graph for.
        project_modules: Modules of the project under analysis.
        local_modules: Local modules of the project.
        filename: Path of the file being visited.
        module_definitions: Optional pre-existing module definitions.

    Returns:
        CFG built from the visitor's nodes and blackbox assignments.
    """
    visitor = ExprVisitor(
        node,
        project_modules,
        local_modules,
        filename,
        module_definitions
    )
    return CFG(visitor.nodes, visitor.blackbox_assignments)
"""This module implements the fixed point algorithm."""
from .constraint_table import constraint_table
class FixedPointAnalysis():
    """Run the fix point analysis."""

    def __init__(self, cfg, analysis):
        """Fixed point analysis.

        Args:
            cfg: The control flow graph to analyse.
            analysis: A dataflow analysis class providing a
                'fixpointmethod' that analyses one CFG node, plus
                'equal' and 'dep'.
        """
        self.analysis = analysis(cfg)
        self.cfg = cfg

    def fixpoint_runner(self):
        """Work list algorithm that runs the fixpoint algorithm.

        Iterates until no node's constraint changes any more.
        """
        # Copy the node list: the original code aliased self.cfg.nodes,
        # so the worklist appends below mutated the CFG's own node list.
        q = list(self.cfg.nodes)
        while q:
            x_i = constraint_table[q[0]]        # x_i = q[0].old_constraint
            self.analysis.fixpointmethod(q[0])  # y = F_i(x_1, ..., x_n);
            y = constraint_table[q[0]]          # y = q[0].new_constraint

            if not self.analysis.equal(y, x_i):
                # Constraint changed: re-process everything depending on q[0].
                for node in self.analysis.dep(q[0]):  # for (v in dep(v_i))
                    q.append(node)                    # q.append(v)
                constraint_table[q[0]] = y            # x_i = y
            q.pop(0)  # drop the processed head (was `q = q[1:]`, an O(n) copy)
def analyse(cfg_list, *, analysis_type):
    """Analyse a list of control flow graphs with a given analysis type."""
    for cfg in cfg_list:
        FixedPointAnalysis(cfg, analysis_type).fixpoint_runner()
"""A generic framework adaptor that leaves route criteria to the caller."""
import ast
from .ast_helper import Arguments
from .expr_visitor import make_cfg
from .module_definitions import project_definitions
from .node_types import (
AssignmentNode,
TaintedNode
)
class FrameworkAdaptor():
    """An engine that uses the template pattern to find all
    entry points in a framework and then taints their arguments.
    """

    def __init__(self, cfg_list, project_modules, local_modules, is_route_function):
        """Store inputs and immediately run the adaptor.

        Args:
            cfg_list(list): Existing CFGs; route-function CFGs are appended to it.
            project_modules: Modules of the project under analysis.
            local_modules: Local modules of the project.
            is_route_function: Predicate deciding whether a FunctionDef is an entry point.
        """
        self.cfg_list = cfg_list
        self.project_modules = project_modules
        self.local_modules = local_modules
        self.is_route_function = is_route_function
        # Side effect in the constructor: run() mutates cfg_list right away.
        self.run()

    def get_func_cfg_with_tainted_args(self, definition):
        """Build a function cfg and return it, with all arguments tainted."""
        func_cfg = make_cfg(
            definition.node,
            self.project_modules,
            self.local_modules,
            definition.path,
            definition.module_definitions
        )

        args = Arguments(definition.node.args)
        if args:
            function_entry_node = func_cfg.nodes[0]
            # Rewire the CFG as entry -> arg nodes -> first body node.
            function_entry_node.outgoing = list()
            first_node_after_args = func_cfg.nodes[1]
            first_node_after_args.ingoing = list()

            # We are just going to give all the tainted args the lineno of the def
            definition_lineno = definition.node.lineno

            # Taint all the arguments
            for i, arg in enumerate(args):
                node_type = TaintedNode
                # `self` is not attacker-controlled, so it stays a plain assignment.
                if i == 0 and arg == 'self':
                    node_type = AssignmentNode

                arg_node = node_type(
                    label=arg,
                    left_hand_side=arg,
                    ast_node=None,
                    right_hand_side_variables=[],
                    line_number=definition_lineno,
                    path=definition.path
                )
                function_entry_node.connect(arg_node)
                # 1 and not 0 so that Entry Node remains first in the list
                func_cfg.nodes.insert(1, arg_node)
                arg_node.connect(first_node_after_args)

        return func_cfg

    def find_route_functions_taint_args(self):
        """Find all route functions and taint all of their arguments.

        Yields:
            CFG of each route function, with args marked as tainted.
        """
        for definition in _get_func_nodes():
            if self.is_route_function(definition.node):
                yield self.get_func_cfg_with_tainted_args(definition)

    def run(self):
        """Run find_route_functions_taint_args on each CFG."""
        function_cfgs = list()
        for _ in self.cfg_list:
            function_cfgs.extend(self.find_route_functions_taint_args())
        self.cfg_list.extend(function_cfgs)
def _get_func_nodes():
    """Return every project definition whose node is an ast.FunctionDef."""
    return [
        definition
        for definition in project_definitions.values()
        if isinstance(definition.node, ast.FunctionDef)
    ]
"""Provides helper functions that help with determining if a function is a route function."""
import ast
from .ast_helper import get_call_names
def is_function(function):
    """Treat every function as a route function (no filtering at all)."""
    return True
def is_flask_route_function(ast_node):
    """Check whether any decorator is a call whose final name is 'route'."""
    return any(
        _get_last_of_iterable(get_call_names(decorator.func)) == 'route'
        for decorator in ast_node.decorator_list
        if isinstance(decorator, ast.Call)
    )
def is_django_view_function(ast_node):
    """A Django view function takes `request` as its first argument."""
    args = ast_node.args.args
    return bool(args) and args[0].arg == 'request'
def is_function_without_leading_(ast_node):
    """Accept only functions whose name does not start with an underscore."""
    return not ast_node.name.startswith('_')
def _get_last_of_iterable(iterable):
"""Get last element of iterable."""
item = None
for item in iterable:
pass
return item
import re
import requests
import time
from abc import ABCMeta, abstractmethod
from datetime import date, datetime, timedelta
from . import repo_runner
from .reaching_definitions_taint import ReachingDefinitionsTaintAnalysis
from .repo_runner import add_repo_to_csv, NoEntryPathError
from .save import save_repo_scan
# Module-wide settings for talking to the GitHub search API.
DEFAULT_TIMEOUT_IN_SECONDS = 60
GITHUB_API_URL = 'https://api.github.com'
# Filled in by set_github_api_token(); stays None until that runs.
GITHUB_OAUTH_TOKEN = None
NUMBER_OF_REQUESTS_ALLOWED_PER_MINUTE = 30  # Rate limit is 10 and 30 with auth
SEARCH_CODE_URL = GITHUB_API_URL + '/search/code'
SEARCH_REPO_URL = GITHUB_API_URL + '/search/repositories'
def set_github_api_token():
    """Load the OAuth token file into the GITHUB_OAUTH_TOKEN module global.

    Reads github_access_token.pyt from the working directory; exits the
    program with a hint when the file is missing.
    """
    global GITHUB_OAUTH_TOKEN
    try:
        # `with` closes the file; the original leaked the open handle.
        with open('github_access_token.pyt', 'r') as token_file:
            GITHUB_OAUTH_TOKEN = token_file.read().strip()
    except FileNotFoundError:
        print('Insert your GitHub access token'
              ' in the github_access_token.pyt file in the pyt package'
              ' if you want to use GitHub search.')
        exit(0)
class Languages:
    """Language qualifier strings for GitHub search queries."""
    _prefix = 'language:'
    python = _prefix + 'python'
    javascript = _prefix + 'javascript'
    # Extend with further languages as needed.
class Query:
    """Builds a GitHub search API query string from its components."""

    def __init__(self, base_url, search_string,
                 language=None, repo=None, time_interval=None, per_page=100):
        """Assemble the final query string into self.query_string.

        Args:
            base_url(str): Endpoint, e.g. the repo or code search URL.
            search_string(str): The quoted search term.
            language(Optional[str]): A Languages qualifier.
            repo(Optional[Repo]): Restrict the search to one repository.
            time_interval(Optional[str]): 'YYYY-MM-DD .. YYYY-MM-DD'.
            per_page(int): Results per page, at most 100.
        """
        repo = self._repo_parameter(repo)
        time_interval = self._time_interval_parameter(time_interval)
        search_string = self._search_parameter(search_string)
        per_page = self._per_page_parameter(per_page)
        parameters = self._construct_parameters([search_string,
                                                 language,
                                                 repo,
                                                 time_interval,
                                                 per_page])
        self.query_string = self._construct_query(base_url, parameters)

    def _construct_query(self, base_url, parameters):
        """Join the non-empty parameters onto the base URL."""
        return base_url + '+'.join(parameters)

    def _construct_parameters(self, parameters):
        """Drop empty/None parameters, preserving order."""
        return [parameter for parameter in parameters if parameter]

    def _search_parameter(self, search_string):
        """Quote the search term as the `q` parameter."""
        return '?q="' + search_string + '"'

    def _repo_parameter(self, repo):
        """Return a repo: qualifier, or None when no repo is given."""
        if repo:
            return 'repo:' + repo.name
        return None

    def _time_interval_parameter(self, created):
        """Validate and wrap a 'YYYY-MM-DD .. YYYY-MM-DD' interval.

        Exits with a message when the format is wrong.
        """
        if created:
            # Raw string: the original used '\d' escapes in a normal
            # string, which is a deprecated invalid escape sequence.
            pattern = re.compile(r'\d\d\d\d-\d\d-\d\d \.\. \d\d\d\d-\d\d-\d\d')
            match = pattern.match(created)
            # The original called match.group() unconditionally and raised
            # AttributeError on no match instead of printing this message.
            if match:
                return 'created:"' + match.group() + '"'
            print('The time interval parameter should be '
                  'of the form: "YYYY-MM-DD .. YYYY-MM-DD"')
            exit(1)
        return None

    def _per_page_parameter(self, per_page):
        """Format the pagination parameter; GitHub caps pages at 100."""
        if per_page > 100:
            print('The GitHub api does not allow pages with over 100 results.')
            exit(1)
        return '&per_page={}'.format(per_page)
class IncompleteResultsError(Exception):
    """Raised when GitHub reports incomplete_results for a search."""
class RequestCounter:
    """Sliding-window limiter for GitHub API request timestamps."""

    def __init__(self, timeout=DEFAULT_TIMEOUT_IN_SECONDS):
        """Remember the window length and start with an empty history."""
        self.timeout_in_seconds = timeout  # timeout in seconds
        self.counter = list()

    def append(self, request_time):
        """Record a request time, sleeping first if the window is full."""
        if len(self.counter) < NUMBER_OF_REQUESTS_ALLOWED_PER_MINUTE:
            self.counter.append(request_time)
            return

        delta = request_time - self.counter[0]
        if delta.seconds >= self.timeout_in_seconds:
            # Oldest entry has aged out: rotate it for the new time.
            self.counter.pop(0)
            self.counter.append(request_time)
            return

        # Window is full and too recent: wait out the remainder.
        print('Maximum requests "{}" reached'
              ' timing out for {} seconds.'
              .format(len(self.counter),
                      self.timeout_in_seconds - delta.seconds))
        self.timeout(self.timeout_in_seconds - delta.seconds)
        self.counter.pop(0)
        self.counter.append(datetime.now())

    def timeout(self, time_in_seconds=DEFAULT_TIMEOUT_IN_SECONDS):
        """Sleep for the given number of seconds."""
        time.sleep(time_in_seconds)
class Search(metaclass=ABCMeta):
    """Base class for one GitHub search API request.

    Subclasses implement parse_results to turn the JSON items into
    result objects appended to self.results.
    """

    # Class attribute: one rate-limit window shared by ALL searches.
    request_counter = RequestCounter()

    def __init__(self, query):
        """Issue the request for *query* immediately.

        Args:
            query(Query): Prepared query whose query_string is fetched.
        """
        self.total_count = None
        self.incomplete_results = None
        self.results = list()
        self._request(query.query_string)

    def _request(self, query_string):
        """Perform the HTTP GET, retrying (recursively) on non-200 answers."""
        # Record the request time for rate limiting.
        Search.request_counter.append(datetime.now())
        print('Making request: {}'.format(query_string))

        # NOTE(review): assumes set_github_api_token() ran first —
        # otherwise GITHUB_OAUTH_TOKEN is None and this concatenation raises.
        headers = {'Authorization': 'token ' + GITHUB_OAUTH_TOKEN}
        r = requests.get(query_string, headers=headers)
        json = r.json()
        if r.status_code != 200:
            # Wait out the rate limit, then retry.
            # NOTE(review): unbounded recursion if the API keeps failing.
            print('Bad request:')
            print(r.status_code)
            print(json)
            Search.request_counter.timeout()
            self._request(query_string)
            return
        self.total_count = json['total_count']
        print('Number of results: {}.'.format(self.total_count))
        self.incomplete_results = json['incomplete_results']
        if self.incomplete_results:
            raise IncompleteResultsError()
        self.parse_results(json['items'])

    @abstractmethod
    def parse_results(self, json_results):
        """Subclass hook: convert raw JSON items into result objects."""
        pass
class SearchRepo(Search):
    """Repository search: wraps every JSON item in a Repo."""

    def parse_results(self, json_results):
        self.results.extend(Repo(item) for item in json_results)
class SearchCode(Search):
    """Code search: wraps every JSON item in a File."""

    def parse_results(self, json_results):
        self.results.extend(File(item) for item in json_results)
class File:
    """A code-search hit: file name plus the repository it lives in."""

    def __init__(self, json):
        """Pull the name and owning repository out of a code-search item."""
        self.name = json['name']
        self.repo = Repo(json['repository'])
class Repo:
    """A repository-search hit: HTML URL and full name."""

    def __init__(self, json):
        """Extract html_url and full_name from a repository item."""
        self.url, self.name = json['html_url'], json['full_name']
def get_dates(start_date, end_date=None, interval=7):
    """Yield (begin, end) date pairs stepping *interval* days from start_date.

    Mirrors the original generator: the first pair begins one interval
    BEFORE start_date, and a final pair covers the leftover days.

    Args:
        start_date(date): First day of the range.
        end_date(Optional[date]): Last day; defaults to today, evaluated at
            call time (the original default froze date.today() at import).
        interval(int): Window size in days.
    """
    if end_date is None:
        end_date = date.today()
    delta = end_date - start_date
    full_intervals = delta.days // interval
    if full_intervals <= 0:
        # The original's for/else crashed here with a NameError because the
        # loop variable was never bound; yield the whole range instead.
        if delta.days > 0:
            yield (start_date, end_date)
        return
    for i in range(full_intervals):
        yield (start_date + timedelta(days=(i * interval) - interval),
               start_date + timedelta(days=i * interval))
    # Take care of the remainder of days (the original for/else, which
    # always ran with the final loop index).
    i = full_intervals - 1
    yield (start_date + timedelta(days=i * interval),
           start_date + timedelta(days=i * interval +
                                  interval +
                                  delta.days % interval))
def scan_github(search_string, start_date, analysis_type, analyse_repo_func, csv_path, ui_mode, other_args):
    """Search GitHub week by week and scan every Flask repository found.

    Args:
        search_string(str): Repository search term.
        start_date(date): Start of the search window, advanced in 7-day steps.
        analysis_type: Dataflow analysis class forwarded to the analyser.
        analyse_repo_func: Callable that performs the actual analysis.
        csv_path(str): CSV file successful scans are appended to.
        ui_mode: Forwarded to analyse_repo_func.
        other_args: Forwarded to analyse_repo_func.
    """
    analyse_repo = analyse_repo_func
    for d in get_dates(start_date, interval=7):
        q = Query(SEARCH_REPO_URL, search_string,
                  language=Languages.python,
                  time_interval=str(d[0]) + ' .. ' + str(d[1]),
                  per_page=100)
        s = SearchRepo(q)
        for repo in s.results:
            # Only scan repositories that actually contain a Flask app.
            q = Query(SEARCH_CODE_URL, 'app = Flask(__name__)',
                      Languages.python, repo)
            s = SearchCode(q)
            if s.results:
                r = repo_runner.Repo(repo.url)
                try:
                    r.clone()
                except NoEntryPathError as err:
                    save_repo_scan(repo, r.path, vulnerabilities=None, error=err)
                    continue
                # Was a bare `except:`, which also swallowed
                # SystemExit and KeyboardInterrupt.
                except Exception:
                    save_repo_scan(repo, r.path, vulnerabilities=None, error='Other Error Unknown while cloning :-(')
                    continue
                try:
                    vulnerabilities = analyse_repo(other_args, r, analysis_type, ui_mode)
                    if vulnerabilities:
                        save_repo_scan(repo, r.path, vulnerabilities)
                        add_repo_to_csv(csv_path, r)
                    else:
                        save_repo_scan(repo, r.path, vulnerabilities=None)
                    r.clean_up()
                except SyntaxError as err:
                    save_repo_scan(repo, r.path, vulnerabilities=None, error=err)
                except IOError as err:
                    save_repo_scan(repo, r.path, vulnerabilities=None, error=err)
                except AttributeError as err:
                    save_repo_scan(repo, r.path, vulnerabilities=None, error=err)
                # Was a bare `except:` as well; keep the best-effort
                # behaviour but stop catching exiting exceptions.
                except Exception:
                    save_repo_scan(repo, r.path, vulnerabilities=None, error='Other Error Unknown :-(')
if __name__ == '__main__':
    # NOTE(review): everything below is ad-hoc experimentation code; the
    # unconditional exit() calls make all but the first loop unreachable.
    for x in get_dates(date(2010, 1, 1), interval=93):
        print(x)
    exit()
    # Dead code from here on.
    # NOTE(review): this call does not match scan_github's signature — TODO confirm.
    scan_github('flask', ReachingDefinitionsTaintAnalysis)
    exit()
    q = Query(SEARCH_REPO_URL, 'flask')
    s = SearchRepo(q)
    for repo in s.results[:3]:
        q = Query(SEARCH_CODE_URL, 'app = Flask(__name__)', Languages.python, repo)
        s = SearchCode(q)
        r = repo_runner.Repo(repo.url)
        r.clone()
        print(r.path)
        r.clean_up()
        print(repo.name)
        print(len(s.results))
        print([f.name for f in s.results])
    exit()
    # NOTE(review): RequestCounter takes no positional 'test' argument — TODO confirm.
    r = RequestCounter('test', timeout=2)
    for x in range(15):
        r.append(datetime.now())
    exit()
    dates = get_dates(date(2010, 1, 1))
    # NOTE(review): loop variable shadows the imported `date` name.
    for date in dates:
        q = Query(SEARCH_REPO_URL, 'flask',
                  time_interval=str(date) + ' .. ' + str(date))
        print(q.query_string)
    exit()
    s = SearchRepo(q)
    print(s.total_count)
    print(s.incomplete_results)
    # NOTE(review): Repo exposes `url`, not `URL` — TODO confirm.
    print([r.URL for r in s.results])
    q = Query(SEARCH_CODE_URL, 'import flask', Languages.python, s.results[0])
    s = SearchCode(q)
    #print(s.total_count)
    #print(s.incomplete_results)
    #print([f.name for f in s.results])
import ast
class LabelVisitor(ast.NodeVisitor):
    """Builds a source-like text label for an AST (sub)tree.

    Each visit appends text to self.result; statement, expression and
    operator visitors each emit their textual form.
    """

    def __init__(self):
        self.result = ''

    def handle_comma_separated(self, comma_separated_list):
        """Emit the list elements separated by ', '."""
        if comma_separated_list:
            for element in comma_separated_list[:-1]:
                self.visit(element)
                self.result += ', '
            self.visit(comma_separated_list[-1])

    def visit_Tuple(self, node):
        self.result += '('
        self.handle_comma_separated(node.elts)
        self.result += ')'

    def visit_List(self, node):
        self.result += '['
        self.handle_comma_separated(node.elts)
        self.result += ']'

    def visit_Raise(self, node):
        self.result += 'raise'
        if node.exc:
            self.result += ' '
            self.visit(node.exc)
        if node.cause:
            self.result += ' from '
            self.visit(node.cause)

    def visit_withitem(self, node):
        self.result += 'with '
        self.visit(node.context_expr)
        if node.optional_vars:
            self.result += ' as '
            self.visit(node.optional_vars)

    def visit_Return(self, node):
        # Only the returned value is labelled, not the `return` keyword.
        if node.value:
            self.visit(node.value)

    def visit_Assign(self, node):
        for target in node.targets:
            self.visit(target)
            self.result = ' '.join((self.result, '='))
            self.insert_space()
        self.visit(node.value)

    def visit_AugAssign(self, node):
        self.visit(node.target)
        self.insert_space()
        self.visit(node.op)
        self.result += '='
        self.insert_space()
        self.visit(node.value)

    def visit_Compare(self, node):
        self.visit(node.left)
        self.insert_space()

        for op, comparator in zip(node.ops, node.comparators):
            self.visit(op)
            self.insert_space()
            self.visit(comparator)
            self.insert_space()
        # Drop the trailing space left after the last comparator.
        self.result = self.result.rstrip()

    def visit_BinOp(self, node):
        self.visit(node.left)
        self.insert_space()
        self.visit(node.op)
        self.insert_space()
        self.visit(node.right)

    def visit_UnaryOp(self, node):
        self.visit(node.op)
        self.visit(node.operand)

    def visit_BoolOp(self, node):
        for i, value in enumerate(node.values):
            self.visit(value)
            # The operator text (' and ' / ' or ') separates the values.
            if i != len(node.values) - 1:
                self.visit(node.op)

    def comprehensions(self, node):
        """Shared body for generator/list/set comprehensions.

        Note: `if` clauses of the generators are not emitted.
        """
        self.visit(node.elt)
        for expression in node.generators:
            self.result += ' for '
            self.visit(expression.target)
            self.result += ' in '
            self.visit(expression.iter)

    def visit_GeneratorExp(self, node):
        self.result += '('
        self.comprehensions(node)
        self.result += ')'

    def visit_ListComp(self, node):
        self.result += '['
        self.comprehensions(node)
        self.result += ']'

    def visit_SetComp(self, node):
        self.result += '{'
        self.comprehensions(node)
        self.result += '}'

    def visit_DictComp(self, node):
        self.result += '{'
        self.visit(node.key)
        self.result += ' : '
        self.visit(node.value)
        for expression in node.generators:
            self.result += ' for '
            self.visit(expression.target)
            self.result += ' in '
            self.visit(expression.iter)
        self.result += '}'

    def visit_Attribute(self, node):
        self.visit(node.value)
        self.result += '.'
        self.result += node.attr

    def visit_Call(self, node):
        self.visit(node.func)
        self.result += '('
        if node.keywords and node.args:
            self.handle_comma_separated(node.args)
            self.result += ','
        else:
            self.handle_comma_separated(node.args)
        self.handle_comma_separated(node.keywords)
        self.result += ')'

    def visit_keyword(self, node):
        # node.arg is None for **kwargs; only the value is emitted then.
        if node.arg:
            self.result += node.arg
            self.result += '='
        self.visit(node.value)

    def insert_space(self):
        self.result += ' '

    def visit_NameConstant(self, node):
        self.result += str(node.value)

    def visit_Subscript(self, node):
        self.visit(node.value)
        self.result += '['
        self.slicev(node.slice)
        self.result += ']'

    def slicev(self, node):
        """Emit the text of a subscript's slice.

        Handles both the pre-3.9 ast.Index/ast.ExtSlice wrappers and the
        3.9+ representation where the slice IS the index expression
        (the original crashed on 3.9+ with an AttributeError here).
        """
        if isinstance(node, ast.Slice):
            if node.lower:
                self.visit(node.lower)
            if node.upper:
                self.visit(node.upper)
            if node.step:
                self.visit(node.step)
        elif isinstance(node, ast.expr):
            # Python 3.9+: a plain index is stored directly as the expression.
            self.visit(node)
        else:
            # Python < 3.9: ast.ExtSlice carries .dims, ast.Index carries .value.
            dims = getattr(node, 'dims', None)
            if dims:
                for dim in dims:
                    self.visit(dim)
            elif hasattr(node, 'value'):
                self.visit(node.value)

    # operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift | RShift | BitOr | BitXor | BitAnd | FloorDiv
    def visit_Add(self, node):
        self.result += '+'

    def visit_Sub(self, node):
        self.result += '-'

    def visit_Mult(self, node):
        self.result += '*'

    def visit_MatMult(self, node):
        # Fixed typo: was `vist_MatMult`, so this visitor never fired.
        self.result += 'x'

    def visit_Div(self, node):
        self.result += '/'

    def visit_Mod(self, node):
        self.result += '%'

    def visit_Pow(self, node):
        self.result += '**'

    def visit_LShift(self, node):
        self.result += '<<'

    def visit_RShift(self, node):
        self.result += '>>'

    def visit_BitOr(self, node):
        self.result += '|'

    def visit_BitXor(self, node):
        self.result += '^'

    def visit_BitAnd(self, node):
        self.result += '&'

    def visit_FloorDiv(self, node):
        self.result += '//'

    # cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
    def visit_Eq(self, node):
        self.result += '=='

    def visit_Gt(self, node):
        self.result += '>'

    def visit_Lt(self, node):
        self.result += '<'

    def visit_NotEq(self, node):
        self.result += '!='

    def visit_GtE(self, node):
        self.result += '>='

    def visit_LtE(self, node):
        self.result += '<='

    def visit_Is(self, node):
        self.result += 'is'

    def visit_IsNot(self, node):
        self.result += 'is not'

    def visit_In(self, node):
        self.result += 'in'

    def visit_NotIn(self, node):
        self.result += 'not in'

    # unaryop = Invert | Not | UAdd | USub
    def visit_Invert(self, node):
        self.result += '~'

    def visit_Not(self, node):
        self.result += 'not '

    def visit_UAdd(self, node):
        self.result += '+'

    def visit_USub(self, node):
        self.result += '-'

    # boolop = And | Or
    def visit_And(self, node):
        self.result += ' and '

    def visit_Or(self, node):
        self.result += ' or '

    def visit_Num(self, node):
        self.result += str(node.n)

    def visit_Name(self, node):
        self.result += node.id

    def visit_Str(self, node):
        self.result += "'" + node.s + "'"
from .constraint_table import constraint_table
class Lattice:
    """Bidirectional mapping between lattice elements and bitvector bits."""

    def __init__(self, cfg_nodes, analysis_type):
        """Assign every lattice element a unique bit position.

        el2bv maps element -> bit mask; bv2el lists elements from the
        most significant bit down to bit 0.
        """
        self.el2bv = dict()  # Element to bitvector dictionary
        elements = list(analysis_type.get_lattice_elements(cfg_nodes))
        for index, element in enumerate(elements):
            self.el2bv[element] = 0b1 << index
        # Reverse so index 0 corresponds to the highest bit.
        self.bv2el = elements[::-1]  # Bitvector to element list

    def get_elements(self, number):
        """Decode a bitvector into the list of elements it contains."""
        if number == 0:
            return []
        # Zero-pad the binary form to one digit per known element.
        binary_string = format(number, '0' + str(len(self.bv2el)) + 'b')
        return [
            self.bv2el[position]
            for position, bit in enumerate(binary_string)
            if bit == '1'
        ]

    def in_constraint(self, node1, node2):
        """Checks if node1 is in node2's constraints.

        For instance, if node1 = 010 and node2 = 110:
        010 & 110 = 010 -> has the element.
        """
        constraint = constraint_table[node2]
        if constraint == 0b0:
            return False
        mask = self.el2bv.get(node1)
        if mask is None:
            return False
        return constraint & mask != 0
def print_lattice(cfg_list, analysis_type):
    """Build a Lattice over all CFG nodes, print its mapping and return it."""
    all_nodes = list()
    for cfg in cfg_list:
        all_nodes.extend(cfg.nodes)
    lattice = Lattice(all_nodes, analysis_type)
    print('Lattice:')
    for element, bitvector in lattice.el2bv.items():
        print(str(element) + ': ' + str(bitvector))
    return lattice
import ast
from .analysis_base import AnalysisBase
from .ast_helper import get_call_names_as_string
from .constraint_table import (
constraint_join,
constraint_table
)
from .lattice import Lattice
from .node_types import (
AssignmentNode,
BBorBInode,
EntryOrExitNode
)
from .vars_visitor import VarsVisitor
class LivenessAnalysis(AnalysisBase):
    """Liveness analysis rules implemented.

    (The original docstring said "Reaching definitions", but the class
    implements the backward liveness analysis.)
    """

    def __init__(self, cfg):
        """Initialise the base class with the CFG under analysis."""
        super().__init__(cfg)

    def join(self, cfg_node):
        """Joins all constraints of the outgoing nodes and returns them.

        This represents the JOIN auxiliary definition from Schwartzbach;
        liveness is a backward analysis, hence the join over successors.
        """
        return constraint_join(cfg_node.outgoing)

    def is_output(self, cfg_node):
        """Return True for call nodes whose name contains 'print'."""
        if isinstance(cfg_node.ast_node, ast.Call):
            call_name = get_call_names_as_string(cfg_node.ast_node.func)
            if 'print' in call_name:
                return True
        return False

    def is_condition(self, cfg_node):
        """Return True for if/while tests and for output calls."""
        if isinstance(cfg_node.ast_node, (ast.If, ast.While)):
            return True
        elif self.is_output(cfg_node):
            return True
        return False

    def remove_id_assignment(self, JOIN, cfg_node):
        """Remove the assigned (left-hand side) variables from JOIN."""
        lvars = list()

        if isinstance(cfg_node, BBorBInode):
            lvars.append(cfg_node.left_hand_side)
        else:
            try:
                for expr in cfg_node.ast_node.targets:
                    vv = VarsVisitor()
                    vv.visit(expr)
                    lvars.extend(vv.result)
            except AttributeError:  # If it is AugAssign
                vv = VarsVisitor()
                vv.visit(cfg_node.ast_node.target)
                lvars.extend(vv.result)
        for var in lvars:
            if var in self.lattice.get_elements(JOIN):
                # Remove var from JOIN
                JOIN = JOIN ^ self.lattice.el2bv[var]
        return JOIN

    def add_vars_assignment(self, JOIN, cfg_node):
        """Add the variables read on the right-hand side to JOIN."""
        rvars = list()

        if isinstance(cfg_node, BBorBInode):
            # A conscience decision was made not to include e.g. ~call_N's in RHS vars
            rvars.extend(cfg_node.right_hand_side_variables)
        else:
            vv = VarsVisitor()
            vv.visit(cfg_node.ast_node.value)
            rvars.extend(vv.result)
        for var in rvars:
            # Add var to JOIN
            JOIN = JOIN | self.lattice.el2bv[var]
        return JOIN

    def add_vars_conditional(self, JOIN, cfg_node):
        """Add variables used in a condition or output call to JOIN."""
        # NOTE(review): varse remains None if none of the branches match;
        # callers guard with is_condition — TODO confirm that invariant.
        varse = None
        if isinstance(cfg_node.ast_node, ast.While):
            vv = VarsVisitor()
            vv.visit(cfg_node.ast_node.test)
            varse = vv.result
        elif self.is_output(cfg_node):
            vv = VarsVisitor()
            vv.visit(cfg_node.ast_node)
            varse = vv.result
        elif isinstance(cfg_node.ast_node, ast.If):
            vv = VarsVisitor()
            vv.visit(cfg_node.ast_node.test)
            varse = vv.result

        for var in varse:
            JOIN = JOIN | self.lattice.el2bv[var]
        return JOIN

    def fixpointmethod(self, cfg_node):
        """Transfer function: compute the liveness constraint of one node."""
        if isinstance(cfg_node, EntryOrExitNode) and 'Exit' in cfg_node.label:
            # Nothing is live at the exit node.
            constraint_table[cfg_node] = 0
        elif isinstance(cfg_node, AssignmentNode):
            JOIN = self.join(cfg_node)
            JOIN = self.remove_id_assignment(JOIN, cfg_node)
            JOIN = self.add_vars_assignment(JOIN, cfg_node)
            constraint_table[cfg_node] = JOIN
        elif self.is_condition(cfg_node):
            JOIN = self.join(cfg_node)
            JOIN = self.add_vars_conditional(JOIN, cfg_node)
            constraint_table[cfg_node] = JOIN
        else:
            constraint_table[cfg_node] = self.join(cfg_node)

    def dep(self, q_1):
        """Represents the dep mapping from Schwartzbach.

        Yields the outgoing nodes of q_1.
        """
        for node in q_1.outgoing:
            yield node

    def get_lattice_elements(cfg_nodes):
        """Returns all variables as they are the only lattice elements
        in the liveness analysis.

        This is a static method which is overwritten from the base class.
        """
        lattice_elements = set()  # set() to avoid duplicates
        for node in (node for node in cfg_nodes if node.ast_node):
            vv = VarsVisitor()
            vv.visit(node.ast_node)
            for var in vv.result:
                lattice_elements.add(var)
        return lattice_elements

    def equal(self, value, other):
        """Bitvector equality used by the fixed point runner."""
        return value == other

    def build_lattice(self, cfg):
        """Construct the lattice over all variables of this CFG."""
        self.lattice = Lattice(cfg.nodes, LivenessAnalysis)
"""This module handles module definitions
which basically is a list of module definition."""
import ast
# Contains all project definitions for a program run:
project_definitions = dict()
class ModuleDefinition():
    """A single named definition (function/class) belonging to a module."""

    # Class-level defaults; instances overwrite these in __init__ /
    # set_definition_node.
    module_definitions = None
    name = None
    node = None
    path = None

    def __init__(
        self,
        local_module_definitions,
        name,
        parent_module_name,
        path
    ):
        """Record the owning module collection and qualify the name.

        When a parent module is given, the stored name becomes
        '<parent>.<name>' (the parent may be a str or an ast.alias).
        """
        self.module_definitions = local_module_definitions
        self.parent_module_name = parent_module_name
        self.path = path

        if not parent_module_name:
            self.name = name
        elif isinstance(parent_module_name, ast.alias):
            self.name = parent_module_name.name + '.' + name
        else:
            self.name = parent_module_name + '.' + name

    def __str__(self):
        name = self.name if self.name else 'NoName'
        node = str(self.node) if self.node else 'NoNode'
        return "Path:" + self.path + " " + self.__class__.__name__ + ': ' + ';'.join((name, node))
class LocalModuleDefinition(ModuleDefinition):
    """A definition local to the module currently being visited."""
class ModuleDefinitions():
    """A collection of module definition.

    Adds to the project definitions list.
    """

    def __init__(self, import_names=None, module_name=None, is_init=False, filename=None):
        """Optionally set import names and module name.

        Module name should only be set when it is a normal import statement.
        """
        self.import_names = import_names
        # module_name is sometimes ast.alias or a string
        self.module_name = module_name
        self.is_init = is_init
        self.filename = filename
        self.definitions = list()
        self.classes = list()
        self.import_alias_mapping = dict()

    def append_if_local_or_in_imports(self, definition):
        """Add definition to list.

        Handles local definitions and adds to project_definitions.
        """
        if isinstance(definition, LocalModuleDefinition):
            self.definitions.append(definition)
        elif self.import_names == ["*"]:
            self.definitions.append(definition)
        elif self.import_names and definition.name in self.import_names:
            self.definitions.append(definition)
        elif (self.import_alias_mapping and definition.name in
              self.import_alias_mapping.values()):
            self.definitions.append(definition)

        # NOTE(review): a definition can be appended twice — once above
        # and once here via parent_module_name — TODO confirm intended.
        if definition.parent_module_name:
            self.definitions.append(definition)

        # Register globally, keyed by AST node, first writer wins.
        if definition.node not in project_definitions:
            project_definitions[definition.node] = definition

    def get_definition(self, name):
        """Get definitions by name.

        Returns None when no definition matches.
        """
        for definition in self.definitions:
            if definition.name == name:
                return definition

    def set_definition_node(self, node, name):
        """Set definition by name."""
        definition = self.get_definition(name)
        if definition:
            definition.node = node

    def __str__(self):
        module = 'NoModuleName'
        if self.module_name:
            module = self.module_name

        if self.definitions:
            if isinstance(module, ast.alias):
                return (
                    'Definitions: "' + '", "'
                    .join([str(definition) for definition in self.definitions]) +
                    '" and module_name: ' + module.name +
                    ' and filename: ' + str(self.filename) +
                    ' and is_init: ' + str(self.is_init) + '\n')
            return (
                'Definitions: "' + '", "'
                .join([str(definition) for definition in self.definitions]) +
                '" and module_name: ' + module +
                ' and filename: ' + str(self.filename) +
                ' and is_init: ' + str(self.is_init) + '\n')
        else:
            if isinstance(module, ast.alias):
                return (
                    'import_names is ' + str(self.import_names) +
                    ' No Definitions, module_name: ' + str(module.name) +
                    ' and filename: ' + str(self.filename) +
                    ' and is_init: ' + str(self.is_init) + '\n')
            return (
                'import_names is ' + str(self.import_names) +
                ' No Definitions, module_name: ' + str(module) +
                ' and filename: ' + str(self.filename) +
                ' and is_init: ' + str(self.is_init) + '\n')
"""This module contains all of the CFG nodes types."""
from collections import namedtuple
from .label_visitor import LabelVisitor
# Bundle returned by control-flow visitors: the test node, the last nodes
# of each branch, and any break statements seen inside the construct.
ControlFlowNode = namedtuple(
    'ControlFlowNode',
    ('test', 'last_nodes', 'break_statements')
)
class IgnoredNode():
    """Sentinel returned for an ast node that should not produce a CFG node."""
class ConnectToExitNode():
    """A common type between raise's and return's, used in return_handler."""
class Node():
    """A Control Flow Graph node holding its label, position and the lists
    of ingoing and outgoing neighbour nodes."""

    def __init__(self, label, ast_node, *, line_number=None, path):
        """Create a Node that can be used in a CFG.

        Args:
            label(str): The label of the node, describing its expression.
            ast_node: The AST node this CFG node was built from (may be None).
            line_number(Optional[int]): The line of the expression of the Node.
            path: Filename the node belongs to.
        """
        self.label = label
        self.ast_node = ast_node
        # Prefer an explicit line number; otherwise fall back to the AST node.
        if line_number:
            self.line_number = line_number
        else:
            self.line_number = ast_node.lineno if ast_node else None
        self.path = path

        self.ingoing = list()
        self.outgoing = list()

    def as_dict(self):
        """Serializable summary of this node."""
        return {
            'label': self.label.encode('utf-8').decode('utf-8'),
            'line_number': self.line_number,
            'path': self.path,
        }

    def connect(self, successor):
        """Connect this node to its successor node by
        setting its outgoing and the successors ingoing."""
        # Raise/return nodes only ever connect to the exit node.
        if isinstance(self, ConnectToExitNode) and not isinstance(successor, EntryOrExitNode):
            return

        self.outgoing.append(successor)
        successor.ingoing.append(self)

    def connect_predecessors(self, predecessors):
        """Connect all nodes in predecessors to this node."""
        for predecessor in predecessors:
            self.ingoing.append(predecessor)
            predecessor.outgoing.append(self)

    def __str__(self):
        """Print the label of the node."""
        return ' Label: ' + self.label

    def __repr__(self):
        """Multi-line representation with label, position and neighbours."""
        parts = (
            'Label:  ' + self.label,
            'Line number: ' + str(self.line_number),
            'ingoing:\t ' + str([n.label for n in self.ingoing]),
            'outgoing:\t ' + str([n.label for n in self.outgoing]),
        )
        return '\n' + '\n'.join(parts)
class BreakNode(Node):
    """CFG Node that represents a Break statement."""

    def __init__(self, ast_node, *, path):
        """Label the node with the class name itself."""
        super().__init__(self.__class__.__name__, ast_node, path=path)
class IfNode(Node):
"""CFG Node that represents an If statement."""
def __init__(self, test_node, ast_node, *, path):
label_visitor = LabelVisitor()
label_visitor.visit(test_node)
super().__init__(
'if ' + label_visitor.result + ':',
ast_node,
path=path
)
class TryNode(Node):
    """CFG Node that represents a Try statement."""

    def __init__(self, ast_node, *, path):
        """The label is always the literal `try:`."""
        super().__init__('try:', ast_node, path=path)
class EntryOrExitNode(Node):
    """CFG Node that represents an Exit or an Entry node."""

    def __init__(self, label):
        """Entry/exit nodes carry no AST node, line number or path."""
        super().__init__(label, None, line_number=None, path=None)
class RaiseNode(Node, ConnectToExitNode):
    """CFG Node that represents a Raise statement."""

    def __init__(self, ast_node, *, path):
        """Label the node with the raise statement's own text."""
        visitor = LabelVisitor()
        visitor.visit(ast_node)
        super().__init__(visitor.result, ast_node, path=path)
class AssignmentNode(Node):
    """CFG Node that represents an assignment."""

    def __init__(self, label, left_hand_side, ast_node, right_hand_side_variables, *, line_number=None, path):
        """Create an Assignment node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            ast_node(_ast.Assign, _ast.AugAssign, _ast.Return or None)
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            line_number(Optional[int]): The line of the expression the Node represents.
            path(string): Current filename.
        """
        super().__init__(label, ast_node, line_number=line_number, path=path)
        self.left_hand_side = left_hand_side
        self.right_hand_side_variables = right_hand_side_variables

    def __repr__(self):
        """Extend Node's repr with LHS/RHS details."""
        return '{}\nleft_hand_side:\t{}\nright_hand_side_variables:\t{}'.format(
            super().__repr__(),
            self.left_hand_side,
            self.right_hand_side_variables,
        )
class TaintedNode(AssignmentNode):
    """CFG Node that represents a tainted node.

    Only created in framework_adaptor.py and only used in
    `identify_triggers` of vulnerabilities.py.
    """
    # Behaves exactly like AssignmentNode; the subclass only serves as a
    # marker type for taint sources.
    pass
class RestoreNode(AssignmentNode):
    """Node used for handling restore nodes returning from function calls."""

    def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_number, path):
        """Create a Restore node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            line_number(Optional[int]): The line of the expression the Node represents.
            path(string): Current filename.
        """
        # A restore node is synthetic, so it has no AST node of its own.
        super().__init__(
            label,
            left_hand_side,
            None,
            right_hand_side_variables,
            line_number=line_number,
            path=path,
        )
class BBorBInode(AssignmentNode):
    """Node used for handling restore nodes returning from blackbox or builtin function calls."""

    def __init__(self, label, left_hand_side, right_hand_side_variables, *, line_number, path, func_name):
        """Create a blackbox-or-builtin (BBorBI) node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            line_number(Optional[int]): The line of the expression the Node represents.
            path(string): Current filename.
            func_name(string): The string we will compare with the blackbox_mapping in vulnerabilities.py
        """
        # Synthetic node: no AST node of its own.
        super().__init__(label, left_hand_side, None, right_hand_side_variables, line_number=line_number, path=path)
        # Argument nodes of the call; filled in by the visitor.
        self.args = list()
        # Defaults to self so callers can always follow the nested-call chain.
        self.inner_most_call = self
        self.func_name = func_name
class AssignmentCallNode(AssignmentNode):
    """Node used for when a call happens inside of an assignment."""

    def __init__(
        self,
        label,
        left_hand_side,
        ast_node,
        right_hand_side_variables,
        *,
        line_number,
        path,
        call_node
    ):
        """Create an Assignment Call node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            ast_node
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            line_number(Optional[int]): The line of the expression the Node represents.
            path(string): Current filename.
            call_node(BBorBInode or RestoreNode): Used in connect_control_flow_node.
        """
        super().__init__(
            label,
            left_hand_side,
            ast_node,
            right_hand_side_variables,
            line_number=line_number,
            path=path
        )
        self.call_node = call_node
        # Flipped to True elsewhere when the call is identified as blackbox.
        self.blackbox = False
class ReturnNode(AssignmentNode, ConnectToExitNode):
    """CFG node that represents a return from a call.

    Inherits ConnectToExitNode so it is only ever connected to the exit.
    """

    def __init__(
        self,
        label,
        left_hand_side,
        ast_node,
        right_hand_side_variables,
        *,
        path
    ):
        """Create a return from a call node.

        Args:
            label(str): The label of the node, describing the expression it represents.
            left_hand_side(str): The variable on the left hand side of the assignment. Used for analysis.
            ast_node
            right_hand_side_variables(list[str]): A list of variables on the right hand side.
            path(string): Current filename.
        """
        # The line number is always taken from the return's own AST node.
        super().__init__(
            label,
            left_hand_side,
            ast_node,
            right_hand_side_variables,
            line_number=ast_node.lineno,
            path=path
        )
"""Generates a list of CFGs from a path.
The module finds all python modules and generates an ast for them.
"""
import os
# Module-level cache of (module_name, path) tuples for the directory most
# recently scanned by get_directory_modules().
local_modules = list()
def get_directory_modules(directory):
    """Return a list containing tuples of
    e.g. ('__init__', 'example/import_test_project/__init__.py')

    Results are cached in the module-level `local_modules` list; the cache
    is reused when the previous call scanned the same directory.
    """
    # Cache hit: the first cached entry lives in this very directory.
    if local_modules and os.path.dirname(local_modules[0][1]) == directory:
        return local_modules
    if not os.path.isdir(directory):
        # example/import_test_project/A.py -> example/import_test_project
        directory = os.path.dirname(directory)
    if directory == '':
        return local_modules
    for path in os.listdir(directory):
        if is_python_file(path):
            # A.py -> A
            module_name = os.path.splitext(path)[0]
            local_modules.append((module_name, os.path.join(directory, path)))
    return local_modules
def get_modules(path):
    """Return a list containing tuples of
    e.g. ('test_project.utils', 'example/test_project/utils.py')

    Walks `path` recursively and builds a dotted module name for every
    Python file found.
    """
    module_root = os.path.split(path)[1]
    modules = list()
    for root, directories, filenames in os.walk(path):
        for filename in filenames:
            if is_python_file(filename):
                # Directory part of the dotted name, relative to module_root.
                directory = os.path.dirname(
                    os.path.realpath(os.path.join(root, filename))
                ).split(module_root)[-1].replace(os.sep, '.')
                directory = directory.replace('.', '', 1)
                # Bug fix: use splitext so only the extension is stripped;
                # `filename.replace('.py', '')` also mangled names that
                # contain '.py' elsewhere (e.g. 'my.pyx.py').
                module_name = os.path.splitext(filename)[0]
                if directory:
                    modules.append(('.'.join((module_root, directory, module_name)), os.path.join(root, filename)))
                else:
                    modules.append(('.'.join((module_root, module_name)), os.path.join(root, filename)))
    return modules
def get_modules_and_packages(path):
    """Return a list containing tuples of
    e.g. ('folder', 'example/test_project/folder', '.folder')
         ('test_project.utils', 'example/test_project/utils.py')

    Packages yield 3-tuples (name, path, relative dotted path); modules
    yield 2-tuples (dotted name, path).
    """
    module_root = os.path.split(path)[1]
    modules = list()
    for root, directories, filenames in os.walk(path):
        for directory in directories:
            if directory != '__pycache__':
                full_path = os.path.join(root, directory)
                relative_path = os.path.realpath(full_path).split(module_root)[-1].replace(os.sep, '.')
                # Remove the dot in front to be consistent
                modules.append((relative_path[1:], full_path, relative_path))
        for filename in filenames:
            if is_python_file(filename):
                full_path = os.path.join(root, filename)
                directory = os.path.dirname(os.path.realpath(full_path)).split(module_root)[-1].replace(os.sep, '.')
                directory = directory.replace('.', '', 1)
                # Bug fix: use splitext so only the extension is stripped;
                # `filename.replace('.py', '')` also mangled names that
                # contain '.py' elsewhere (e.g. 'my.pyx.py').
                module_name = os.path.splitext(filename)[0]
                if directory:
                    modules.append(('.'.join((module_root, directory, module_name)), full_path))
                else:
                    modules.append(('.'.join((module_root, module_name)), full_path))
    return modules
def is_python_file(path):
    """Return True if `path` has a .py extension."""
    # Return the comparison directly instead of branching to True/False.
    return os.path.splitext(path)[1] == '.py'
from .analysis_base import AnalysisBase
from .constraint_table import constraint_join
from .lattice import Lattice
from .node_types import AssignmentNode
class ReachingDefinitionsAnalysisBase(AnalysisBase):
    """Reaching definitions analysis rules implemented."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def join(self, cfg_node):
        """Joins all constraints of the ingoing nodes and returns them.

        This represents the JOIN auxiliary definition from Schwartzbach."""
        return constraint_join(cfg_node.ingoing)

    def arrow(self, JOIN, _id):
        """Removes all previous assignments from JOIN that have the same left hand side.

        This represents the arrow id definition from Schwartzbach."""
        result = JOIN
        for node in self.lattice.get_elements(JOIN):
            if node.left_hand_side == _id:
                # XOR clears the node's bit: get_elements only yields
                # elements whose bit is already set in JOIN.
                result = result ^ self.lattice.el2bv[node]
        return result

    def fixpointmethod(self, cfg_node):
        """Transfer function; subclasses must implement it."""
        raise NotImplementedError()

    def dep(self, q_1):
        """Represents the dep mapping from Schwartzbach."""
        yield from q_1.outgoing

    @staticmethod
    def get_lattice_elements(cfg_nodes):
        """Returns all assignment nodes as they are the only lattice elements
        in the reaching definitions analysis.

        Bug fix: marked @staticmethod (the docstring already described it as
        static) so it behaves the same whether called on the class, as
        Lattice does, or on an instance; without the decorator an instance
        call would bind `self` to cfg_nodes.
        """
        for node in cfg_nodes:
            if isinstance(node, AssignmentNode):
                yield node

    def equal(self, value, other):
        return value == other

    def build_lattice(self, cfg):
        self.lattice = Lattice(cfg.nodes, ReachingDefinitionsAnalysisBase)
from .constraint_table import constraint_table
from .node_types import AssignmentNode
from .reaching_definitions_base import ReachingDefinitionsAnalysisBase
class ReachingDefinitionsTaintAnalysis(ReachingDefinitionsAnalysisBase):
    """Reaching definitions analysis rules implemented."""

    def fixpointmethod(self, cfg_node):
        """Transfer function: update constraint_table for cfg_node."""
        incoming = self.join(cfg_node)
        if not isinstance(cfg_node, AssignmentNode):
            # Default case: non-assignments just propagate the join.
            constraint_table[cfg_node] = incoming
            return
        result = incoming
        # Only kill earlier definitions on a genuine reassignment; when the
        # LHS also appears on the RHS, earlier definitions still reach.
        if cfg_node.left_hand_side not in cfg_node.right_hand_side_variables:
            result = self.arrow(incoming, cfg_node.left_hand_side)
        constraint_table[cfg_node] = result | self.lattice.el2bv[cfg_node]
from .constraint_table import constraint_table
from .node_types import AssignmentNode
from .reaching_definitions_base import ReachingDefinitionsAnalysisBase
class ReachingDefinitionsAnalysis(ReachingDefinitionsAnalysisBase):
    """Reaching definitions analysis rules implemented."""

    def fixpointmethod(self, cfg_node):
        """Transfer function: update constraint_table for cfg_node."""
        incoming = self.join(cfg_node)
        if isinstance(cfg_node, AssignmentNode):
            # Kill earlier definitions of the same LHS, then add this one.
            surviving = self.arrow(incoming, cfg_node.left_hand_side)
            constraint_table[cfg_node] = surviving | self.lattice.el2bv[cfg_node]
        else:
            # Default case: propagate the join unchanged.
            constraint_table[cfg_node] = incoming
"""This modules runs PyT on a CSV file of git repos."""
import os
import shutil
import git
# Fallback CSV of repositories to scan when no -csv path is supplied.
DEFAULT_CSV_PATH = 'flask_open_source_apps.csv'
class NoEntryPathError(Exception):
    """Raised when no Flask entry file can be located in a cloned repo."""
    pass
class Repo:
    """Holder for a repo with git URL and
    a path to where the analysis should start"""

    def __init__(self, URL, path=None):
        self.URL = URL.strip()
        self.path = path.strip() if path else None
        self.directory = None

    def clone(self):
        """Clone repo and update path to match the current one"""
        pieces = self.URL.split('/')[-1].split('.')
        # Strip a trailing extension such as '.git' from the repo name.
        if len(pieces) > 1:
            self.directory = '.'.join(pieces[:-1])
        else:
            self.directory = pieces[0]
        if self.directory not in os.listdir():
            git.Git().clone(self.URL)
        if self.path is None:
            self._find_entry_path()
        else:
            # Join relative to the clone directory, dropping a leading '/'.
            relative = self.path[1:] if self.path[0] == '/' else self.path
            self.path = os.path.join(self.directory, relative)

    def _find_entry_path(self):
        """Locate the first file that creates the Flask app object."""
        for root, dirs, files in os.walk(self.directory):
            for name in files:
                if name.endswith('.py'):
                    with open(os.path.join(root, name), 'r') as fd:
                        if 'app = Flask(__name__)' in fd.read():
                            self.path = os.path.join(root, name)
                            return
        raise NoEntryPathError('No entry path found in repo {}.'
                               .format(self.URL))

    def clean_up(self):
        """Deletes the repo"""
        shutil.rmtree(self.directory)
def get_repos(csv_path):
    """Parse a CSV file of '<url>,<path>' lines into Repo objects.

    Blank lines are skipped (the original crashed on them, and
    add_repo_to_file's leading os.linesep can produce them), and only the
    first comma splits, so entry paths may themselves contain commas.
    """
    repos = list()
    with open(csv_path, 'r') as fd:
        for line in fd:
            if not line.strip():
                continue
            url, path = line.split(',', 1)
            repos.append(Repo(url, path))
    return repos
def add_repo_to_file(path, repo):
    """Append repo's URL and entry path as one CSV line to `path`.

    Exits the program when the fallback CSV file does not exist.
    """
    entry = '{}{}, {}'.format(os.linesep, repo.URL, repo.path)
    try:
        with open(path, 'a') as fd:
            fd.write(entry)
    except FileNotFoundError:
        print('-csv handle not used and fallback path not found: {}'
              .format(DEFAULT_CSV_PATH))
        print('You need to specify the csv_path'
              ' by using the "-csv" handle.')
        exit(1)
def add_repo_to_csv(csv_path, repo):
    """Append repo to csv_path, falling back to DEFAULT_CSV_PATH."""
    target = DEFAULT_CSV_PATH if csv_path is None else csv_path
    add_repo_to_file(target, repo)
"""Contains a class that finds all names.
Used to find all variables on a right hand side(RHS) of assignment.
"""
import ast
class RHSVisitor(ast.NodeVisitor):
    """Visitor collecting all names."""

    def __init__(self):
        """Initialize result as list."""
        self.result = list()

    def visit_Name(self, node):
        self.result.append(node.id)

    def visit_Call(self, node):
        # Only arguments carry names into the result; the callee itself is
        # deliberately not visited. The original guarded each loop with a
        # redundant truthiness check -- iterating an empty list is a no-op.
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword)
import os
from datetime import datetime
from .definition_chains import (
build_def_use_chain,
build_use_def_chain
)
from .formatters import text
from .lattice import Lattice
from .node_types import Node
# Output file that accumulates the generated SQL statements.
database_file_name = 'db.sql'
# Table names used in the generated DDL/DML below.
nodes_table_name = 'nodes'
vulnerabilities_table_name = 'vulnerabilities'
def create_nodes_table():
    """Append DROP/CREATE statements for the nodes table to the SQL file."""
    with open(database_file_name, 'a') as fd:
        # Bug fix: terminate the DROP statement with ';' -- the original
        # wrote only a newline, producing invalid SQL.
        fd.write('DROP TABLE IF EXISTS ' + nodes_table_name + ';\n')
        fd.write('CREATE TABLE ' + nodes_table_name + '(id int,label varchar(255),line_number int, path varchar(255));')
def create_vulnerabilities_table():
    """Append DROP/CREATE statements for the vulnerabilities table to the SQL file."""
    with open(database_file_name, 'a') as fd:
        # Bug fix: terminate the DROP statement with ';' -- the original
        # wrote only a newline, producing invalid SQL.
        fd.write('DROP TABLE IF EXISTS ' + vulnerabilities_table_name + ';\n')
        fd.write('CREATE TABLE ' + vulnerabilities_table_name + '(id int, source varchar(255), source_word varchar(255), sink varchar(255), sink_word varchar(255));')
def quote(item):
    """Return item (or a Node's label) as a single-quoted SQL literal."""
    text = item.label if isinstance(item, Node) else item
    # Escape embedded quotes by doubling them, per SQL convention.
    return "'{}'".format(text.replace("'", "''"))
def insert_vulnerability(vulnerability):
    """Append an INSERT for one vulnerability to the SQL file.

    Bug fix: the column list is now explicit -- the table has five columns
    (including id) but only four values are written, so the original
    positional INSERT was invalid SQL.
    """
    with open(database_file_name, 'a') as fd:
        fd.write('\nINSERT INTO ' + vulnerabilities_table_name +
                 ' (source, source_word, sink, sink_word)\n')
        fd.write('VALUES (')
        fd.write(quote(vulnerability.__dict__['source']) + ',')
        fd.write(quote(vulnerability.__dict__['source_trigger_word']) + ',')
        fd.write(quote(vulnerability.__dict__['sink']) + ',')
        fd.write(quote(vulnerability.__dict__['sink_trigger_word']))
        fd.write(');')
def insert_node(node):
    """Append an INSERT for one CFG node to the SQL file.

    Bug fix: the column list is now explicit -- the table has four columns
    (including id) but only three values are written, so the original
    positional INSERT was invalid SQL.
    """
    with open(database_file_name, 'a') as fd:
        fd.write('\nINSERT INTO ' + nodes_table_name + ' (label, line_number, path)\n')
        fd.write('VALUES (')
        fd.write("'" + node.__dict__['label'].replace("'", "''") + "'" + ',')
        line_number = node.__dict__['line_number']
        if line_number:
            fd.write(str(line_number) + ',')
        else:
            fd.write('NULL,')
        path = node.__dict__['path']
        if path:
            fd.write("'" + path.replace("'", "''") + "'")
        else:
            fd.write('NULL')
        fd.write(');')
def create_database(cfg_list, vulnerabilities):
    """Write SQL creating and populating both tables for all CFGs."""
    create_nodes_table()
    for cfg in cfg_list:
        for cfg_node in cfg.nodes:
            insert_node(cfg_node)
    create_vulnerabilities_table()
    for vuln in vulnerabilities:
        insert_vulnerability(vuln)
class Output():
    """Context manager yielding a writable file named after `title`.

    When the class attribute `filename_prefix` is set (truthy), it is
    prepended to every title as '<prefix>_<title>'.
    """
    filename_prefix = None

    def __init__(self, title):
        prefix = Output.filename_prefix
        self.title = prefix + '_' + title if prefix else title

    def __enter__(self):
        self.fd = open(self.title, 'w')
        return self.fd

    def __exit__(self, type, value, traceback):
        self.fd.close()
def def_use_chain_to_file(cfg_list):
    """Write the def-use chain of every CFG to 'def-use_chain.pyt'."""
    with Output('def-use_chain.pyt') as fd:
        for i, cfg in enumerate(cfg_list):
            fd.write('##### Def-use chain for CFG {} #####{}'
                     .format(i, os.linesep))
            # NOTE(review): build_def_use_chain appears to take
            # (cfg_nodes, lattice) elsewhere in the package -- confirm this
            # single-argument call matches the imported definition.
            def_use = build_def_use_chain(cfg.nodes)
            for k, v in def_use.items():
                fd.write('Def: {} -> Use: [{}]{}'
                         .format(k.label,
                                 ', '.join([n.label for n in v]),
                                 os.linesep))
def use_def_chain_to_file(cfg_list):
    """Write the use-def chain of every CFG to 'use-def_chain.pyt'."""
    with Output('use-def_chain.pyt') as fd:
        for i, cfg in enumerate(cfg_list):
            fd.write('##### Use-def chain for CFG {} #####{}'
                     .format(i, os.linesep))
            # NOTE(review): variable name says def_use but holds a
            # use -> definitions mapping; values look like tuples whose
            # second element is the defining node -- confirm against
            # build_use_def_chain's actual return shape.
            def_use = build_use_def_chain(cfg.nodes)
            for k, v in def_use.items():
                fd.write('Use: {} -> Def: [{}]{}'
                         .format(k.label,
                                 ', '.join([n[1].label for n in v]),
                                 os.linesep))
def cfg_to_file(cfg_list):
    """Write a plain listing of every CFG to 'control_flow_graph.pyt'."""
    with Output('control_flow_graph.pyt') as fd:
        # Separate counters: the original reused `i` for both loops.
        for cfg_index, cfg in enumerate(cfg_list):
            fd.write('##### CFG {} #####{}'.format(cfg_index, os.linesep))
            for node_index, node in enumerate(cfg.nodes):
                fd.write('Node {}: {}{}'.format(node_index, node.label, os.linesep))
def verbose_cfg_to_file(cfg_list):
    """Write a repr() listing of every CFG to 'verbose_control_flow_graph.pyt'."""
    with Output('verbose_control_flow_graph.pyt') as fd:
        # Separate counters: the original reused `i` for both loops.
        for cfg_index, cfg in enumerate(cfg_list):
            fd.write('##### CFG {} #####{}'.format(cfg_index, os.linesep))
            for node_index, node in enumerate(cfg.nodes):
                fd.write('Node {}: {}{}'.format(node_index, repr(node), os.linesep))
def lattice_to_file(cfg_list, analysis_type):
    """Write each CFG's lattice mappings to 'lattice.pyt'."""
    with Output('lattice.pyt') as fd:
        for i, cfg in enumerate(cfg_list):
            fd.write('##### Lattice for CFG {} #####{}'.format(i, os.linesep))
            lattice = Lattice(cfg.nodes, analysis_type)
            fd.write('# Elements to bitvector #{}'.format(os.linesep))
            for element, bitvector in lattice.el2bv.items():
                fd.write('{} -> {}{}'.format(str(element), bin(bitvector), os.linesep))
            fd.write('# Bitvector to elements #{}'.format(os.linesep))
            # Reverse listing: same mapping with the columns swapped.
            for element, bitvector in lattice.el2bv.items():
                fd.write('{} -> {}{}'.format(bin(bitvector), str(element), os.linesep))
def vulnerabilities_to_file(vulnerabilities):
    """Write a text report of all vulnerabilities to 'vulnerabilities.pyt'."""
    with Output('vulnerabilities.pyt') as fd:
        text.report(vulnerabilities, fd)
def save_repo_scan(repo, entry_path, vulnerabilities, error=None):
    """Append one repo's scan results (and optional error) to 'scan.pyt'.

    NOTE: expects an object exposing `.name` and `.url` attributes.
    """
    header = [
        '{}{}'.format(repo.name, os.linesep),
        '{}{}'.format(repo.url, os.linesep),
        'Entry file: {}{}'.format(entry_path, os.linesep),
        'Scanned: {}{}'.format(datetime.now(), os.linesep),
    ]
    with open('scan.pyt', 'a') as fd:
        fd.writelines(header)
        if vulnerabilities:
            text.report(vulnerabilities, fd)
        else:
            fd.write('No vulnerabilities found.{}'.format(os.linesep))
        if error:
            fd.write('An Error occurred while scanning the repo: {}'
                     .format(str(error)))
            fd.write(os.linesep)
        fd.write(os.linesep)
import ast
import random
from collections import namedtuple
from .node_types import (
AssignmentCallNode,
BBorBInode,
BreakNode,
ControlFlowNode,
RestoreNode
)
# Prefix marking labels of call-related helper nodes (e.g. '~call_1');
# connect_nodes skips nodes carrying it because they are wired elsewhere.
CALL_IDENTIFIER = '~'
# Result bundle of stmt_star_handler: the entry node of a statement list,
# its possible exit nodes, and any break statements seen within it.
ConnectStatements = namedtuple(
    'ConnectStatements',
    (
        'first_statement',
        'last_statements',
        'break_statements'
    )
)
def _get_inner_most_function_call(call_node):
    """Follow nested calls until reaching the innermost one.

    e.g. return scrypt.inner in `foo = scrypt.outer(scrypt.inner(image_name))`
    """
    previous = None
    # Iterate until descending no longer changes the node.
    while call_node != previous:
        previous = call_node
        if isinstance(call_node, BBorBInode):
            call_node = call_node.inner_most_call
        else:
            try:
                # e.g. save_2_blah, even when there is a save_3_blah
                call_node = call_node.first_node
            except AttributeError:
                # No inner calls.
                # Possible improvement: Make new node for RestoreNode's made in process_function
                # and make `self.inner_most_call = self`
                # So that we can duck type and not catch an exception when there are no inner calls.
                # This is what we do in BBorBInode
                pass
    return call_node
def _connect_control_flow_node(control_flow_node, next_node):
    """Connect a ControlFlowNode properly to the next_node."""
    for last in control_flow_node.last_nodes:
        if isinstance(next_node, ControlFlowNode):
            # Branch straight into the next statement's test node.
            last.connect(next_node.test)
        elif isinstance(next_node, AssignmentCallNode):
            # Route through the innermost call of the assignment's RHS.
            last.connect(_get_inner_most_function_call(next_node.call_node))
        else:
            last.connect(next_node)
def connect_nodes(nodes):
    """Connect the nodes in a list linearly."""
    for current, upcoming in zip(nodes, nodes[1:]):
        if isinstance(current, ControlFlowNode):
            _connect_control_flow_node(current, upcoming)
        elif isinstance(upcoming, ControlFlowNode):
            current.connect(upcoming.test)
        elif isinstance(upcoming, RestoreNode):
            # Restore nodes are wired up elsewhere.
            continue
        elif CALL_IDENTIFIER in upcoming.label:
            # Call helper nodes are wired up elsewhere.
            continue
        else:
            current.connect(upcoming)
def _get_names(node, result):
"""Recursively finds all names."""
if isinstance(node, ast.Name):
return node.id + result
elif isinstance(node, ast.Subscript):
return result
else:
return _get_names(node.value, result + '.' + node.attr)
def extract_left_hand_side(target):
    """Extract the left hand side variable from a target.

    Removes list indexes, stars and other left hand side elements.
    """
    left_hand_side = _get_names(target, '')
    # Bug fix: str.replace returns a new string; the original discarded
    # the result, so '*' was never actually removed.
    left_hand_side = left_hand_side.replace('*', '')
    if '[' in left_hand_side:
        # Bug fix: slice the string, not the AST node (`target[:index]`
        # would raise TypeError).
        index = left_hand_side.index('[')
        left_hand_side = left_hand_side[:index]
    return left_hand_side
def get_first_node(
    node,
    node_not_to_step_past
):
    """
    This is a super hacky way of getting the first node after a statement.
    We do this because we visit a statement and keep on visiting and get something in return that is rarely the first node.
    So we loop and loop backwards until we hit the statement or there is nothing to step back to.
    """
    ingoing = None
    i = 0
    current_node = node
    while current_node.ingoing:
        # This is used because there may be multiple ingoing and loop will cause an infinite loop if we did [0]
        i = random.randrange(len(current_node.ingoing))
        # e.g. We don't want to step past the Except of an Except basic block
        if current_node.ingoing[i] == node_not_to_step_past:
            break
        ingoing = current_node.ingoing
        current_node = current_node.ingoing[i]
    if ingoing:
        # NOTE(review): `i` indexes the most recently inspected ingoing list,
        # which may differ from `ingoing` when the loop broke on the stop
        # node -- confirm this is the intended "node we last stepped through".
        return ingoing[i]
    return current_node
def get_first_statement(node_or_tuple):
    """Return the first element when given a tuple, else the argument itself."""
    if isinstance(node_or_tuple, tuple):
        return node_or_tuple[0]
    return node_or_tuple
def get_last_statements(cfg_statements):
    """Retrieve the last statements from a cfg_statements list."""
    final = cfg_statements[-1]
    if isinstance(final, ControlFlowNode):
        return final.last_nodes
    return [final]
def remove_breaks(last_statements):
    """Return last_statements with every BreakNode filtered out."""
    return [
        statement
        for statement in last_statements
        if not isinstance(statement, BreakNode)
    ]
import ast
import itertools
import os.path
from .alias_helper import (
as_alias_handler,
handle_aliases_in_init_files,
handle_fdid_aliases,
not_as_alias_handler,
retrieve_import_alias_mapping
)
from .ast_helper import (
generate_ast,
get_call_names_as_string
)
from .label_visitor import LabelVisitor
from .module_definitions import (
LocalModuleDefinition,
ModuleDefinition,
ModuleDefinitions
)
from .node_types import (
AssignmentNode,
AssignmentCallNode,
BBorBInode,
BreakNode,
ControlFlowNode,
EntryOrExitNode,
IfNode,
IgnoredNode,
Node,
RaiseNode,
ReturnNode,
TryNode
)
from .project_handler import get_directory_modules
from .right_hand_side_visitor import RHSVisitor
from .stmt_visitor_helper import (
CALL_IDENTIFIER,
ConnectStatements,
connect_nodes,
extract_left_hand_side,
get_first_node,
get_first_statement,
get_last_statements,
remove_breaks
)
from .vars_visitor import VarsVisitor
class StmtVisitor(ast.NodeVisitor):
def visit_Module(self, node):
    """Build the CFG for a whole module body."""
    return self.stmt_star_handler(node.body)
def stmt_star_handler(
    self,
    stmts,
    prev_node_to_avoid=None
):
    """Handle stmt* expressions in an AST node.

    Links all statements together in a list of statements, accounting for statements with multiple last nodes.

    Args:
        stmts(list[ast.stmt]): Statements to visit and link.
        prev_node_to_avoid(Optional[Node]): Node that get_first_node must not step past.

    Returns:
        ConnectStatements for the visited statements, or IgnoredNode when
        every statement was ignored.
    """
    break_nodes = list()
    cfg_statements = list()
    self.prev_nodes_to_avoid.append(prev_node_to_avoid)
    self.last_control_flow_nodes.append(None)
    first_node = None
    node_not_to_step_past = self.nodes[-1]
    for stmt in stmts:
        node = self.visit(stmt)
        # Remember the latest control-flow test (but not a bare try:) so
        # later statements can attach to it.
        if isinstance(node, ControlFlowNode) and not isinstance(node.test, TryNode):
            self.last_control_flow_nodes.append(node.test)
        else:
            self.last_control_flow_nodes.append(None)
        # Collect break statements so the enclosing loop can connect them.
        if isinstance(node, ControlFlowNode):
            break_nodes.extend(node.break_statements)
        elif isinstance(node, BreakNode):
            break_nodes.append(node)
        if not isinstance(node, IgnoredNode):
            cfg_statements.append(node)
            # Capture the entry node of the very first real statement.
            if not first_node:
                if isinstance(node, ControlFlowNode):
                    first_node = node.test
                else:
                    first_node = get_first_node(
                        node,
                        node_not_to_step_past
                    )
    self.prev_nodes_to_avoid.pop()
    self.last_control_flow_nodes.pop()
    connect_nodes(cfg_statements)
    if cfg_statements:
        if first_node:
            first_statement = first_node
        else:
            first_statement = get_first_statement(cfg_statements[0])
        last_statements = get_last_statements(cfg_statements)
        return ConnectStatements(
            first_statement=first_statement,
            last_statements=last_statements,
            break_statements=break_nodes
        )
    else:  # When body of module only contains ignored nodes
        return IgnoredNode()
def get_parent_definitions(self):
parent_definitions = None
if len(self.module_definitions_stack) > 1:
parent_definitions = self.module_definitions_stack[-2]
return parent_definitions
def add_to_definitions(self, node):
    """Record a function/class definition in local (and parent) module definitions.

    Args:
        node(ast.FunctionDef or ast.ClassDef): The definition being recorded.
    """
    local_definitions = self.module_definitions_stack[-1]
    parent_definitions = self.get_parent_definitions()
    if parent_definitions:
        # Qualify with enclosing class names, e.g. 'Outer.method'.
        parent_qualified_name = '.'.join(
            parent_definitions.classes +
            [node.name]
        )
        parent_definition = ModuleDefinition(
            parent_definitions,
            parent_qualified_name,
            local_definitions.module_name,
            self.filenames[-1]
        )
        parent_definition.node = node
        parent_definitions.append_if_local_or_in_imports(parent_definition)
    local_qualified_name = '.'.join(local_definitions.classes +
                                    [node.name])
    local_definition = LocalModuleDefinition(
        local_definitions,
        local_qualified_name,
        None,
        self.filenames[-1]
    )
    local_definition.node = node
    local_definitions.append_if_local_or_in_imports(local_definition)
    # Track known function names so later calls to them can be resolved.
    self.function_names.append(node.name)
def visit_ClassDef(self, node):
    """Register the class, visit its body in class scope; no CFG node results."""
    self.add_to_definitions(node)
    local_definitions = self.module_definitions_stack[-1]
    parent_definitions = self.get_parent_definitions()
    # Push the class name onto the scope stacks while visiting the body.
    local_definitions.classes.append(node.name)
    if parent_definitions:
        parent_definitions.classes.append(node.name)
    self.stmt_star_handler(node.body)
    if parent_definitions:
        parent_definitions.classes.pop()
    local_definitions.classes.pop()
    return IgnoredNode()
def visit_FunctionDef(self, node):
    """Only record the definition; the body is visited when the function is called."""
    self.add_to_definitions(node)
    return IgnoredNode()
def handle_or_else(self, orelse, test):
    """Handle the orelse part of an if or try node.

    Args:
        orelse(list[Node])
        test(Node)

    Returns:
        The last nodes of the orelse branch.
    """
    if isinstance(orelse[0], ast.If):
        # An elif: visit it as a nested If and relabel 'if ...' -> 'elif ...'.
        control_flow_node = self.visit(orelse[0])
        control_flow_node.test.label = 'el' + control_flow_node.test.label
        test.connect(control_flow_node.test)
        return control_flow_node.last_nodes
    else_stmts = self.stmt_star_handler(
        orelse,
        prev_node_to_avoid=self.nodes[-1]
    )
    test.connect(else_stmts.first_statement)
    return else_stmts.last_statements
def visit_If(self, node):
    """Build CFG nodes for an if statement, including elif/else branches."""
    test = self.append_node(IfNode(
        node.test,
        node,
        path=self.filenames[-1]
    ))
    body_connect_stmts = self.stmt_star_handler(node.body)
    if isinstance(body_connect_stmts, IgnoredNode):
        # Empty/ignored body: treat the test itself as the whole branch.
        body_connect_stmts = ConnectStatements(
            first_statement=test,
            last_statements=[],
            break_statements=[]
        )
    test.connect(body_connect_stmts.first_statement)
    if node.orelse:
        orelse_last_nodes = self.handle_or_else(node.orelse, test)
        body_connect_stmts.last_statements.extend(orelse_last_nodes)
    else:
        body_connect_stmts.last_statements.append(test)  # if there is no orelse, test needs an edge to the next_node
    last_statements = remove_breaks(body_connect_stmts.last_statements)
    return ControlFlowNode(test, last_statements, break_statements=body_connect_stmts.break_statements)
def visit_Raise(self, node):
    """Create a RaiseNode, which only connects onward to the exit."""
    raise_node = RaiseNode(node, path=self.filenames[-1])
    return self.append_node(raise_node)
def visit_Return(self, node):
    """Create a ReturnNode assigning the returned value to 'ret_<function>'."""
    label = LabelVisitor()
    label.visit(node)
    try:
        rhs_visitor = RHSVisitor()
        rhs_visitor.visit(node.value)
    except AttributeError:
        # Bare `return`: node.value is None so visiting it raises.
        # NOTE(review): result is normally a list of names; here it becomes
        # the string 'EmptyReturn' -- confirm downstream consumers tolerate
        # iterating a string.
        rhs_visitor.result = 'EmptyReturn'
    this_function_name = self.function_return_stack[-1]
    LHS = 'ret_' + this_function_name
    if isinstance(node.value, ast.Call):
        # Visit the call first, then assign its result to ret_<function>.
        return_value_of_call = self.visit(node.value)
        return_node = ReturnNode(
            LHS + ' = ' + return_value_of_call.left_hand_side,
            LHS,
            node,
            [return_value_of_call.left_hand_side],
            path=self.filenames[-1]
        )
        return_value_of_call.connect(return_node)
        self.nodes.append(return_node)
        return return_node
    return self.append_node(ReturnNode(
        LHS + ' = ' + label.result,
        LHS,
        node,
        rhs_visitor.result,
        path=self.filenames[-1]
    ))
def handle_stmt_star_ignore_node(self, body, fallback_cfg_node):
    """Connect fallback_cfg_node to body, or substitute it when body was ignored.

    Args:
        body(ConnectStatements or IgnoredNode)
        fallback_cfg_node(Node): Stands in for an ignored/empty body.
    """
    try:
        fallback_cfg_node.connect(body.first_statement)
    except AttributeError:
        # body was an IgnoredNode (it has no first_statement attribute).
        # NOTE(review): first_statement is a list here but a single node in
        # the normal case -- confirm callers handle both shapes.
        body = ConnectStatements(
            first_statement=[fallback_cfg_node],
            last_statements=[fallback_cfg_node],
            break_statements=[]
        )
    return body
def visit_Try(self, node):
    """Build CFG nodes for try/except/else/finally."""
    try_node = self.append_node(TryNode(
        node,
        path=self.filenames[-1]
    ))
    body = self.stmt_star_handler(node.body)
    body = self.handle_stmt_star_ignore_node(body, try_node)
    last_statements = list()
    for handler in node.handlers:
        try:
            name = handler.type.id
        except AttributeError:
            # Bare `except:` has no exception type node.
            name = ''
        handler_node = self.append_node(Node(
            'except ' + name + ':',
            handler,
            line_number=handler.lineno,
            path=self.filenames[-1]
        ))
        # Any last statement of the body may raise into each handler.
        for body_node in body.last_statements:
            body_node.connect(handler_node)
        handler_body = self.stmt_star_handler(handler.body)
        handler_body = self.handle_stmt_star_ignore_node(handler_body, handler_node)
        last_statements.extend(handler_body.last_statements)
    if node.orelse:
        orelse_last_nodes = self.handle_or_else(node.orelse, body.last_statements[-1])
        body.last_statements.extend(orelse_last_nodes)
    if node.finalbody:
        # Both handler exits and body exits flow into the finally block.
        finalbody = self.stmt_star_handler(node.finalbody)
        for last in last_statements:
            last.connect(finalbody.first_statement)
        for last in body.last_statements:
            last.connect(finalbody.first_statement)
        body.last_statements.extend(finalbody.last_statements)
    last_statements.extend(remove_breaks(body.last_statements))
    return ControlFlowNode(try_node, last_statements, break_statements=body.break_statements)
def assign_tuple_target(self, node, right_hand_side_variables):
    """Handle e.g. `x, y = a, b()` by creating one assignment node per element.

    Args:
        node(ast.Assign): Assignment whose target and value are both tuples.
        right_hand_side_variables(list[str]): Names used on the RHS.
    """
    new_assignment_nodes = list()
    for i, target in enumerate(node.targets[0].elts):
        # Pair each target element with the value element at the same index.
        value = node.value.elts[i]
        label = LabelVisitor()
        label.visit(target)
        if isinstance(value, ast.Call):
            # Synthesize a single-target assignment for the call element.
            new_ast_node = ast.Assign(target, value)
            new_ast_node.lineno = node.lineno
            new_assignment_nodes.append(self.assignment_call_node(label.result, new_ast_node))
        else:
            label.result += ' = '
            label.visit(value)
            new_assignment_nodes.append(self.append_node(AssignmentNode(
                label.result,
                extract_left_hand_side(target),
                ast.Assign(target, value),
                right_hand_side_variables,
                line_number=node.lineno,
                path=self.filenames[-1]
            )))
    connect_nodes(new_assignment_nodes)
    return ControlFlowNode(new_assignment_nodes[0], [new_assignment_nodes[-1]], [])  # return the last added node
def assign_multi_target(self, node, right_hand_side_variables):
    """Handle `x = y = value` by creating one AssignmentNode per target."""
    assignment_nodes = list()
    for target in node.targets:
        label = LabelVisitor()
        label.visit(target)
        lhs = label.result
        label.result += ' = '
        label.visit(node.value)
        assignment_nodes.append(self.append_node(AssignmentNode(
            label.result,
            lhs,
            ast.Assign(target, node.value),
            right_hand_side_variables,
            line_number=node.lineno,
            path=self.filenames[-1]
        )))
    connect_nodes(assignment_nodes)
    # The last added node represents the whole statement.
    return ControlFlowNode(assignment_nodes[0], [assignment_nodes[-1]], [])
def visit_Assign(self, node):
    """Dispatch an assignment to the proper handler based on its shape."""
    rhs_visitor = RHSVisitor()
    rhs_visitor.visit(node.value)
    if isinstance(node.targets[0], ast.Tuple):  # x,y = [1,2]
        if isinstance(node.value, ast.Tuple):
            return self.assign_tuple_target(node, rhs_visitor.result)
        elif isinstance(node.value, ast.Call):
            call = None
            # One call node per unpacked element; the last one is returned.
            for element in node.targets[0].elts:
                label = LabelVisitor()
                label.visit(element)
                call = self.assignment_call_node(label.result, node)
            return call
        else:
            # Unhandled shape (e.g. tuple target with list value): warn and
            # fall back to a single coarse assignment node.
            label = LabelVisitor()
            label.visit(node)
            print('Assignment not properly handled.',
                  'Could result in not finding a vulnerability.',
                  'Assignment:', label.result)
            return self.append_node(AssignmentNode(
                label.result,
                label.result,
                node,
                rhs_visitor.result,
                path=self.filenames[-1]
            ))
    elif len(node.targets) > 1:  # x = y = 3
        return self.assign_multi_target(node, rhs_visitor.result)
    else:
        if isinstance(node.value, ast.Call):  # x = call()
            label = LabelVisitor()
            label.visit(node.targets[0])
            return self.assignment_call_node(label.result, node)
        else:  # x = 4
            label = LabelVisitor()
            label.visit(node)
            return self.append_node(AssignmentNode(
                label.result,
                extract_left_hand_side(node.targets[0]),
                node,
                rhs_visitor.result,
                path=self.filenames[-1]
            ))
def assignment_call_node(self, left_hand_label, ast_node):
    """Handle assignments that contain a function call on its right side.

    Args:
        left_hand_label(str): Label of the assignment's LHS.
        ast_node(ast.Assign): The assignment containing the call.

    Returns:
        AssignmentCallNode connected after the visited call node.
    """
    self.undecided = True  # Used for handling functions in assignments
    call = self.visit(ast_node.value)
    call_label = call.left_hand_side
    if isinstance(call, BBorBInode):
        # Necessary to know e.g.
        # `image_name = image_name.replace('..', '')`
        # is a reassignment.
        vars_visitor = VarsVisitor()
        vars_visitor.visit(ast_node.value)
        call.right_hand_side_variables.extend(vars_visitor.result)
    call_assignment = AssignmentCallNode(
        left_hand_label + ' = ' + call_label,
        left_hand_label,
        ast_node,
        [call.left_hand_side],
        line_number=ast_node.lineno,
        path=self.filenames[-1],
        call_node=call
    )
    call.connect(call_assignment)
    # Appended directly (not via append_node) after wiring the call edge.
    self.nodes.append(call_assignment)
    self.undecided = False
    return call_assignment
def visit_AugAssign(self, node):
    """Create an AssignmentNode for an augmented assignment, e.g. `x += y`."""
    statement_label = LabelVisitor()
    statement_label.visit(node)

    value_visitor = RHSVisitor()
    value_visitor.visit(node.value)

    return self.append_node(AssignmentNode(
        statement_label.result,
        extract_left_hand_side(node.target),
        node,
        value_visitor.result,
        path=self.filenames[-1]
    ))
def loop_node_skeleton(self, test, node):
    """Common handling of looped structures, while and for.

    Args:
        test(Node): the loop-test (or `for ... in ...:`) node.
        node(ast.While|ast.For): AST node of the loop statement.

    Returns:
        ControlFlowNode wrapping the test node and the loop's exit nodes.
    """
    body_connect_stmts = self.stmt_star_handler(
        node.body,
        prev_node_to_avoid=self.nodes[-1]
    )
    test.connect(body_connect_stmts.first_statement)
    # Back-edges: the last statements of the body connect to the test.
    test.connect_predecessors(body_connect_stmts.last_statements)

    # last_nodes is used for making connections to the next node in the parent node
    # this is handled in stmt_star_handler
    last_nodes = list()
    last_nodes.extend(body_connect_stmts.break_statements)

    if node.orelse:
        orelse_connect_stmts = self.stmt_star_handler(
            node.orelse,
            prev_node_to_avoid=self.nodes[-1]
        )
        test.connect(orelse_connect_stmts.first_statement)
        last_nodes.extend(orelse_connect_stmts.last_statements)
    else:
        last_nodes.append(test)  # if there is no orelse, test needs an edge to the next_node
    return ControlFlowNode(test, last_nodes, list())
def visit_For(self, node):
    """Create the `for <target> in <iter>:` node and wire the loop body."""
    self.undecided = False  # reset the flag set by assignment_call_node

    iterator_label = LabelVisitor()
    iterator_label.visit(node.iter)
    target_label = LabelVisitor()
    target_label.visit(node.target)

    for_node = self.append_node(Node(
        "for " + target_label.result + " in " + iterator_label.result + ':',
        node,
        path=self.filenames[-1]
    ))

    # When iterating over a user-defined function's return value, visit
    # the call first so the for node is connected after it.
    if isinstance(node.iter, ast.Call) and get_call_names_as_string(node.iter.func) in self.function_names:
        last_node = self.visit(node.iter)
        last_node.connect(for_node)
    return self.loop_node_skeleton(for_node, node)
def visit_While(self, node):
label_visitor = LabelVisitor()
label_visitor.visit(node.test)
test = self.append_node(Node(
'while ' + label_visitor.result + ':',
node,
path=self.filenames[-1]
))
return self.loop_node_skeleton(test, node)
def add_blackbox_or_builtin_call(self, node, blackbox):
    """Processes a blackbox or builtin function when it is called.

    Nothing gets assigned to ret_func_foo in the builtin/blackbox case.
    Increments self.function_call_index each time it is called, we can refer to it as N in the comments.
    Create e.g. ~call_1 = ret_func_foo RestoreNode.
    Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument.
    Visit the arguments if they're calls. (save_def_args_in_temp)
    I do not think I care about this one actually -- Create e.g. def_arg1 = temp_N_def_arg1 for each argument.
    (create_local_scope_from_def_args)
    Add RestoreNode to the end of the Nodes.

    Args:
        node(ast.Call) : The node that calls the definition.
        blackbox(bool): Whether or not it is a builtin or blackbox call.
    Returns:
        call_node(BBorBInode): The call node.
    """
    self.function_call_index += 1
    saved_function_call_index = self.function_call_index
    self.undecided = False

    call_label = LabelVisitor()
    call_label.visit(node)

    # Everything up to the first '(' is the (possibly dotted) call name.
    index = call_label.result.find('(')

    # Create e.g. ~call_1 = ret_func_foo
    LHS = CALL_IDENTIFIER + 'call_' + str(saved_function_call_index)
    RHS = 'ret_' + call_label.result[:index] + '('

    call_node = BBorBInode(
        label='',
        left_hand_side=LHS,
        right_hand_side_variables=[],
        line_number=node.lineno,
        path=self.filenames[-1],
        func_name=call_label.result[:index]
    )
    visual_args = list()
    rhs_vars = list()
    last_return_value_of_nested_call = None

    for arg in itertools.chain(node.args, node.keywords):
        if isinstance(arg, ast.Call):
            return_value_of_nested_call = self.visit(arg)

            if last_return_value_of_nested_call:
                # connect inner to other_inner in e.g.
                # `scrypt.outer(scrypt.inner(image_name), scrypt.other_inner(image_name))`
                # I should probably loop to the inner most call of other_inner here.
                try:
                    last_return_value_of_nested_call.connect(return_value_of_nested_call.first_node)
                except AttributeError:
                    last_return_value_of_nested_call.connect(return_value_of_nested_call)
            else:
                # I should only set this once per loop, inner in e.g.
                # `scrypt.outer(scrypt.inner(image_name), scrypt.other_inner(image_name))`
                # (inner_most_call is used when predecessor is a ControlFlowNode in connect_control_flow_node)
                call_node.inner_most_call = return_value_of_nested_call
            last_return_value_of_nested_call = return_value_of_nested_call

            # The nested call contributes its ~call_N pseudo-variable.
            visual_args.append(return_value_of_nested_call.left_hand_side)
            rhs_vars.append(return_value_of_nested_call.left_hand_side)
        else:
            label = LabelVisitor()
            label.visit(arg)
            visual_args.append(label.result)

            vv = VarsVisitor()
            vv.visit(arg)
            rhs_vars.extend(vv.result)

    if last_return_value_of_nested_call:
        # connect other_inner to outer in e.g.
        # `scrypt.outer(scrypt.inner(image_name), scrypt.other_inner(image_name))`
        last_return_value_of_nested_call.connect(call_node)

    if len(visual_args) > 0:
        for arg in visual_args:
            RHS = RHS + arg + ", "
        # Replace the last ", " with a )
        RHS = RHS[:len(RHS) - 2] + ')'
    else:
        RHS = RHS + ')'
    call_node.label = LHS + " = " + RHS

    call_node.right_hand_side_variables = rhs_vars
    # Used in get_sink_args, not using right_hand_side_variables because it is extended in assignment_call_node
    rhs_visitor = RHSVisitor()
    rhs_visitor.visit(node)
    call_node.args = rhs_visitor.result

    if blackbox:
        self.blackbox_assignments.add(call_node)

    self.connect_if_allowed(self.nodes[-1], call_node)
    self.nodes.append(call_node)

    return call_node
def visit_With(self, node):
    """Create a node for the first with-item and connect the body after it."""
    item_label = LabelVisitor()
    item_label.visit(node.items[0])

    with_node = self.append_node(Node(
        item_label.result,
        node,
        path=self.filenames[-1]
    ))
    body = self.stmt_star_handler(node.body)
    with_node.connect(body.first_statement)
    return ControlFlowNode(
        with_node,
        body.last_statements,
        body.break_statements
    )
def visit_Break(self, node):
    """Append a BreakNode; stmt_star_handler wires it past the loop later."""
    break_node = BreakNode(
        node,
        path=self.filenames[-1]
    )
    return self.append_node(break_node)
def visit_Delete(self, node):
    """Append a plain node labelled `del <targets>`.

    NOTE(review): a single LabelVisitor is reused across all targets, so
    multiple targets are rendered concatenated — confirm this is the
    intended label for `del a, b`.
    """
    targets_label = LabelVisitor()
    for target in node.targets:
        targets_label.visit(target)
    return self.append_node(Node(
        'del ' + targets_label.result,
        node,
        path=self.filenames[-1]
    ))
def visit_Assert(self, node):
    """Append a plain node labelled with the asserted test expression."""
    test_label = LabelVisitor()
    test_label.visit(node.test)

    return self.append_node(Node(
        test_label.result,
        node,
        path=self.filenames[-1]
    ))
def visit_Continue(self, node):
    """`continue` becomes a miscellaneous node labelled 'continue'."""
    return self.visit_miscelleaneous_node(node, custom_label='continue')
def visit_Global(self, node):
    """`global x` becomes a miscellaneous node labelled by the statement itself."""
    return self.visit_miscelleaneous_node(node)
def visit_Pass(self, node):
    """`pass` becomes a miscellaneous node labelled 'pass'."""
    return self.visit_miscelleaneous_node(node, custom_label='pass')
def visit_miscelleaneous_node(
    self,
    node,
    custom_label=None
):
    """Append a plain Node for a statement with no control-flow effect.

    Uses custom_label when given (and truthy); otherwise labels the node
    with the LabelVisitor rendering of the statement itself.
    """
    if not custom_label:
        rendered = LabelVisitor()
        rendered.visit(node)
        custom_label = rendered.result

    return self.append_node(Node(
        custom_label,
        node,
        path=self.filenames[-1]
    ))
def visit_Expr(self, node):
    """An expression statement is represented by visiting its value expression."""
    return self.visit(node.value)
def append_node(self, node):
    """Append a node to the CFG and return it.

    Args:
        node(Node): node to add to self.nodes.

    Returns:
        The same node, for convenient chaining.
    """
    self.nodes.append(node)
    return node
def add_module(
    self,
    module,
    module_or_package_name,
    local_names,
    import_alias_mapping,
    is_init=False,
    from_from=False,
    from_fdid=False
):
    """Analyse an imported module's file and splice its CFG into this one.

    Args:
        module: (name, path) tuple of the module to analyse.
        module_or_package_name: name used to qualify the module's
            definitions in the parent; may be None or an ast.alias.
        local_names: names the import statement makes visible locally.
        import_alias_mapping(dict): `as` aliases of the import.
        is_init(bool): whether the analysed file is an __init__.py.
        from_from(bool): whether this came from `from X import Y`.
        from_fdid(bool): "from directory import directory" handling.

    Returns:
        The ExitNode that gets attached to the CFG of the class.
    """
    module_path = module[1]

    parent_definitions = self.module_definitions_stack[-1]
    # The only place the import_alias_mapping is updated
    parent_definitions.import_alias_mapping.update(import_alias_mapping)
    parent_definitions.import_names = local_names

    new_module_definitions = ModuleDefinitions(local_names, module_or_package_name)
    new_module_definitions.is_init = is_init
    self.module_definitions_stack.append(new_module_definitions)

    # Analyse the file
    self.filenames.append(module_path)
    self.local_modules = get_directory_modules(module_path)
    tree = generate_ast(module_path)

    # module[0] is None during e.g. "from . import foo", so we must str()
    self.nodes.append(EntryOrExitNode('Module Entry ' + str(module[0])))
    self.visit(tree)
    exit_node = self.append_node(EntryOrExitNode('Module Exit ' + str(module[0])))

    # Done analysing, pop the module off
    self.module_definitions_stack.pop()
    self.filenames.pop()

    if new_module_definitions.is_init:
        # Export the __init__ file's definitions to the parent, applying
        # aliases from either the module's or the parent's mapping.
        for def_ in new_module_definitions.definitions:
            module_def_alias = handle_aliases_in_init_files(
                def_.name,
                new_module_definitions.import_alias_mapping
            )
            parent_def_alias = handle_aliases_in_init_files(
                def_.name,
                parent_definitions.import_alias_mapping
            )
            # They should never both be set
            assert not (module_def_alias and parent_def_alias)
            def_name = def_.name
            if parent_def_alias:
                def_name = parent_def_alias
            if module_def_alias:
                def_name = module_def_alias

            # Sanity check: after the pop above, the stack top must be
            # the parent definitions again.
            local_definitions = self.module_definitions_stack[-1]
            if local_definitions != parent_definitions:
                raise  # NOTE(review): bare raise outside except — would be clearer as an explicit exception

            if not isinstance(module_or_package_name, str):
                # e.g. an ast.alias from `import X as Y`
                module_or_package_name = module_or_package_name.name

            if module_or_package_name:
                if from_from:
                    qualified_name = def_name

                    if from_fdid:
                        alias = handle_fdid_aliases(module_or_package_name, import_alias_mapping)
                        if alias:
                            module_or_package_name = alias
                        parent_definition = ModuleDefinition(
                            parent_definitions,
                            qualified_name,
                            module_or_package_name,
                            self.filenames[-1]
                        )
                    else:
                        parent_definition = ModuleDefinition(
                            parent_definitions,
                            qualified_name,
                            None,
                            self.filenames[-1]
                        )
                else:
                    qualified_name = module_or_package_name + '.' + def_name
                    parent_definition = ModuleDefinition(
                        parent_definitions,
                        qualified_name,
                        parent_definitions.module_name,
                        self.filenames[-1]
                    )
                parent_definition.node = def_.node
                parent_definitions.definitions.append(parent_definition)
            else:
                parent_definition = ModuleDefinition(
                    parent_definitions,
                    def_name,
                    parent_definitions.module_name,
                    self.filenames[-1]
                )
                parent_definition.node = def_.node
                parent_definitions.definitions.append(parent_definition)

    return exit_node
def from_directory_import(
    self,
    module,
    real_names,
    local_names,
    import_alias_mapping,
    skip_init=False
):
    """Handle `from <directory> import <names>`.

    Directories don't need to be packages: if an __init__.py exists it
    is analysed as the package; otherwise each imported name is resolved
    to a subdirectory (which must itself be a package) or a .py file.

    Args:
        module: (name, directory_path) tuple.
        real_names: imported names as they appear in the module.
        local_names: names the import makes visible locally.
        import_alias_mapping(dict): `as` aliases of the import.
        skip_init(bool): do not analyse the directory's own __init__.py.

    Returns:
        The module ExitNode for the __init__ case, else an IgnoredNode.
    """
    module_path = module[1]

    init_file_location = os.path.join(module_path, '__init__.py')
    init_exists = os.path.isfile(init_file_location)

    if init_exists and not skip_init:
        package_name = os.path.split(module_path)[1]
        return self.add_module(
            (module[0], init_file_location),
            package_name,
            local_names,
            import_alias_mapping,
            is_init=True,
            from_from=True
        )
    for real_name in real_names:
        full_name = os.path.join(module_path, real_name)
        if os.path.isdir(full_name):
            new_init_file_location = os.path.join(full_name, '__init__.py')
            if os.path.isfile(new_init_file_location):
                self.add_module(
                    (real_name, new_init_file_location),
                    real_name,
                    local_names,
                    import_alias_mapping,
                    is_init=True,
                    from_from=True,
                    from_fdid=True
                )
            else:
                raise Exception('from anything import directory needs an __init__.py file in directory')
        else:
            file_module = (real_name, full_name + '.py')
            self.add_module(
                file_module,
                real_name,
                local_names,
                import_alias_mapping,
                from_from=True
            )
    return IgnoredNode()
def import_package(self, module, module_name, local_name, import_alias_mapping):
    """Analyse a package's __init__.py; raise when the package has none."""
    package_path = module[1]
    init_location = os.path.join(package_path, '__init__.py')

    if not os.path.isfile(init_location):
        raise Exception('import directory needs an __init__.py file')

    return self.add_module(
        (module[0], init_location),
        module_name,
        local_name,
        import_alias_mapping,
        is_init=True
    )
def handle_relative_import(self, node):
    """Resolve a relative import to a file or directory and analyse it.

    from A means node.level == 0
    from . import B means node.level == 1
    from .A means node.level == 1
    """
    # Directory containing the file currently being analysed.
    no_file = os.path.abspath(os.path.join(self.filenames[-1], os.pardir))
    skip_init = False

    if node.level == 1:
        # Same directory as current file
        if node.module:
            name_with_dir = os.path.join(no_file, node.module.replace('.', '/'))
            if not os.path.isdir(name_with_dir):
                name_with_dir = name_with_dir + '.py'
        # e.g. from . import X
        else:
            name_with_dir = no_file
            # We do not want to analyse the init file of the current directory
            skip_init = True
    else:
        parent = os.path.abspath(os.path.join(no_file, os.pardir))
        if node.level > 2:
            # Perform extra `cd ..` however many times
            for _ in range(0, node.level - 2):
                parent = os.path.abspath(os.path.join(parent, os.pardir))
        if node.module:
            name_with_dir = os.path.join(parent, node.module.replace('.', '/'))
            if not os.path.isdir(name_with_dir):
                name_with_dir = name_with_dir + '.py'
        # e.g. from .. import X
        else:
            name_with_dir = parent

    # Is it a file?
    if name_with_dir.endswith('.py'):
        return self.add_module(
            (node.module, name_with_dir),
            None,
            as_alias_handler(node.names),
            retrieve_import_alias_mapping(node.names),
            from_from=True
        )
    return self.from_directory_import(
        (node.module, name_with_dir),
        not_as_alias_handler(node.names),
        as_alias_handler(node.names),
        retrieve_import_alias_mapping(node.names),
        skip_init=skip_init
    )
def visit_Import(self, node):
    """Handle `import X`, analysing X when it is a local or project module.

    Returns:
        The module's ExitNode, or an IgnoredNode when the module cannot
        be found locally (e.g. standard library or third-party imports).
    """
    for name in node.names:
        # Modules in the same directory take precedence over project modules.
        for module in self.local_modules:
            if name.name == module[0]:
                if os.path.isdir(module[1]):
                    return self.import_package(
                        module,
                        name,
                        name.asname,
                        retrieve_import_alias_mapping(node.names)
                    )
                return self.add_module(
                    module,
                    name.name,
                    name.asname,
                    retrieve_import_alias_mapping(node.names)
                )
        for module in self.project_modules:
            if name.name == module[0]:
                if os.path.isdir(module[1]):
                    return self.import_package(
                        module,
                        name,
                        name.asname,
                        retrieve_import_alias_mapping(node.names)
                    )
                return self.add_module(
                    module,
                    name.name,
                    name.asname,
                    retrieve_import_alias_mapping(node.names)
                )
    return IgnoredNode()
def visit_ImportFrom(self, node):
    """Handle `from X import Y` statements.

    Relative imports (node.level > 0) are delegated to
    handle_relative_import. Otherwise the module is looked up first among
    the modules local to the current file's directory, then among the
    project modules.

    Returns:
        The module's ExitNode, or an IgnoredNode when the module cannot
        be found locally (e.g. standard library or third-party imports).
    """
    # Is it relative?
    if node.level > 0:
        return self.handle_relative_import(node)

    for module in self.local_modules:
        if node.module == module[0]:
            if os.path.isdir(module[1]):
                # BUG FIX: the required import_alias_mapping argument was
                # previously omitted here (the project_modules branch below
                # passes it), which made this call raise a TypeError.
                return self.from_directory_import(
                    module,
                    not_as_alias_handler(node.names),
                    as_alias_handler(node.names),
                    retrieve_import_alias_mapping(node.names)
                )
            return self.add_module(
                module,
                None,
                as_alias_handler(node.names),
                retrieve_import_alias_mapping(node.names),
                from_from=True
            )
    for module in self.project_modules:
        if node.module == module[0]:
            if os.path.isdir(module[1]):
                return self.from_directory_import(
                    module,
                    not_as_alias_handler(node.names),
                    as_alias_handler(node.names),
                    retrieve_import_alias_mapping(node.names)
                )
            return self.add_module(
                module,
                None,
                as_alias_handler(node.names),
                retrieve_import_alias_mapping(node.names),
                from_from=True
            )
    return IgnoredNode()
import os
from collections import namedtuple
# Separator between a sink and its sanitisers in a trigger word file,
# e.g. `some_sink(->some_sanitiser`.
SANITISER_SEPARATOR = '->'
# Section headers expected in the trigger word file.
SOURCES_KEYWORD = 'sources:'
SINKS_KEYWORD = 'sinks:'

# Parsed trigger word definitions: `sources` and `sinks` are lists of
# (trigger_word, sanitiser_list) tuples produced by parse_section().
Definitions = namedtuple(
    'Definitions',
    (
        'sources',
        'sinks'
    )
)
def parse_section(iterator):
    """Parse a section of a file. Stops at empty line.

    Each non-empty line is either a plain trigger word, or a
    `sink->sanitiser, sanitiser` pair separated by SANITISER_SEPARATOR.

    Args:
        iterator(File): file descriptor pointing at a definition file.

    Returns:
        Iterator of (trigger_word, sanitiser_list) tuples.
    """
    try:
        line = next(iterator).rstrip()
        while line:
            # `line` is already rstripped and non-empty here, so the old
            # redundant `if line.rstrip():` guard has been removed.
            if SANITISER_SEPARATOR in line:
                sink, _, sanitiser_part = line.partition(SANITISER_SEPARATOR)
                sanitisers = [s.strip() for s in sanitiser_part.split(',')]
                yield (sink.rstrip(), sanitisers)
            else:
                yield (line, list())
            line = next(iterator).rstrip()
    except StopIteration:
        return
def parse(trigger_word_file):
    """Parse the file for source and sink definitions.

    Returns:
        A definitions tuple with sources and sinks.
    """
    sources, sinks = list(), list()

    with open(trigger_word_file, 'r') as fd:
        for raw_line in fd:
            stripped = raw_line.rstrip()
            if stripped == SOURCES_KEYWORD:
                sources = list(parse_section(fd))
            elif stripped == SINKS_KEYWORD:
                sinks = list(parse_section(fd))
    return Definitions(sources, sinks)
import ast
import itertools
from .ast_helper import get_call_names
class VarsVisitor(ast.NodeVisitor):
    """Collect, in visit order, the variable names used in an expression.

    Names are accumulated in `self.result`. Nested calls are recorded as
    'ret_<func>' pseudo-variables so they can be matched against the
    left-hand side of call assignment nodes.
    """

    def __init__(self):
        self.result = list()

    def visit_Name(self, node):
        self.result.append(node.id)

    def visit_BoolOp(self, node):
        for v in node.values:
            self.visit(v)

    def visit_BinOp(self, node):
        self.visit(node.left)
        self.visit(node.right)

    def visit_UnaryOp(self, node):
        self.visit(node.operand)

    def visit_Lambda(self, node):
        self.visit(node.body)

    def visit_IfExpr(self, node):
        self.visit(node.test)
        self.visit(node.body)
        self.visit(node.orelse)

    def visit_Dict(self, node):
        for k in node.keys:
            self.visit(k)
        for v in node.values:
            self.visit(v)

    def visit_Set(self, node):
        for e in node.elts:
            self.visit(e)

    def comprehension(self, node):
        # Shared helper for all comprehension generators.
        self.visit(node.target)
        self.visit(node.iter)
        for c in node.ifs:
            self.visit(c)

    def visit_ListComp(self, node):
        self.visit(node.elt)
        for gen in node.generators:
            self.comprehension(gen)

    def visit_SetComp(self, node):
        self.visit(node.elt)
        for gen in node.generators:
            self.comprehension(gen)

    def visit_DictComp(self, node):
        self.visit(node.key)
        self.visit(node.value)
        for gen in node.generators:
            self.comprehension(gen)

    def visit_GeneratorExp(self, node):
        # BUG FIX: the ast node class is named GeneratorExp, so the old
        # visit_GeneratorComp method was never dispatched by NodeVisitor
        # (generator expressions fell through to generic_visit).
        self.visit(node.elt)
        for gen in node.generators:
            self.comprehension(gen)

    # Kept for backwards compatibility with any direct callers.
    visit_GeneratorComp = visit_GeneratorExp

    def visit_Await(self, node):
        self.visit(node.value)

    def visit_Yield(self, node):
        if node.value:
            self.visit(node.value)

    def visit_YieldFrom(self, node):
        self.visit(node.value)

    def visit_Compare(self, node):
        self.visit(node.left)
        for c in node.comparators:
            self.visit(c)

    def visit_Call(self, node):
        # This will not visit Flask in Flask(__name__) but it will visit request in `request.args.get()
        if not isinstance(node.func, ast.Name):
            self.visit(node.func)
        for arg in itertools.chain(node.args, node.keywords):
            if isinstance(arg, ast.Call):
                if isinstance(arg.func, ast.Name):
                    # We can't just visit because we need to add 'ret_'
                    self.result.append('ret_' + arg.func.id)
                elif isinstance(arg.func, ast.Attribute):
                    # e.g. html.replace('{{ param }}', param)
                    # func.attr is replace
                    # func.value.id is html
                    # We want replace
                    self.result.append('ret_' + arg.func.attr)
                else:
                    # Deal with it when we have code that triggers it.
                    raise
            else:
                self.visit(arg)

    def visit_Attribute(self, node):
        if not isinstance(node.value, ast.Name):
            self.visit(node.value)
        else:
            self.result.append(node.value.id)

    def slicev(self, node):
        # NOTE(review): ast.ExtSlice was removed in newer Python versions;
        # this mirrors the original handling for older ASTs.
        if isinstance(node, ast.Slice):
            if node.lower:
                self.visit(node.lower)
            if node.upper:
                self.visit(node.upper)
            if node.step:
                self.visit(node.step)
        elif isinstance(node, ast.ExtSlice):
            if node.dims:
                for d in node.dims:
                    self.visit(d)
        else:
            self.visit(node.value)

    def visit_Subscript(self, node):
        if isinstance(node.value, ast.Attribute):
            # Record the base name of e.g. `request.args[...]`.
            self.result.append(list(get_call_names(node.value))[0])
        self.visit(node.value)
        self.slicev(node.slice)

    def visit_Starred(self, node):
        self.visit(node.value)

    def visit_List(self, node):
        for el in node.elts:
            self.visit(el)

    def visit_Tuple(self, node):
        for el in node.elts:
            self.visit(el)
"""Module for finding vulnerabilities based on a definitions file."""
import ast
import json
from collections import namedtuple
from .argument_helpers import UImode
from .definition_chains import build_def_use_chain
from .lattice import Lattice
from .node_types import (
AssignmentNode,
BBorBInode,
IfNode,
TaintedNode
)
from .right_hand_side_visitor import RHSVisitor
from .trigger_definitions_parser import parse
from .vars_visitor import VarsVisitor
from .vulnerability_helper import (
vuln_factory,
VulnerabilityType
)
# One sanitiser occurrence: `trigger_word` is the sanitiser string and
# `cfg_node` is a CFG node whose label contains it.
Sanitiser = namedtuple(
    'Sanitiser',
    (
        'trigger_word',
        'cfg_node'
    )
)

# Everything identify_triggers finds in a CFG: source TriggerNodes,
# sink TriggerNodes, and a sanitiser -> [CFG node] dict.
Triggers = namedtuple(
    'Triggers',
    (
        'sources',
        'sinks',
        'sanitiser_dict'
    )
)
class TriggerNode():
    """A source, sink or sanitiser occurrence tied to a CFG node."""

    def __init__(self, trigger_word, sanitisers, cfg_node, secondary_nodes=None):
        """
        Args:
            trigger_word(string): the word that triggered this node.
            sanitisers(list[string]): sanitisers of the sink, None for sources.
            cfg_node(Node): the CFG node the trigger word was found in.
            secondary_nodes(list[Node]): reassignment nodes, if already known.
        """
        self.trigger_word = trigger_word
        self.sanitisers = sanitisers
        self.cfg_node = cfg_node
        # BUG FIX: the old `secondary_nodes=[]` default was a shared
        # mutable default argument; create a fresh list per instance.
        self.secondary_nodes = [] if secondary_nodes is None else secondary_nodes

    def append(self, cfg_node):
        """Record cfg_node as a secondary node, skipping self and duplicates."""
        if cfg_node == self.cfg_node:
            return
        if not self.secondary_nodes:
            self.secondary_nodes = [cfg_node]
        elif cfg_node not in self.secondary_nodes:
            self.secondary_nodes.append(cfg_node)

    def __repr__(self):
        output = 'TriggerNode('
        if self.trigger_word:
            output = '{} trigger_word is {}, '.format(
                output,
                self.trigger_word
            )
        return (
            output +
            'sanitisers are {}, '.format(self.sanitisers) +
            'cfg_node is {})\n'.format(self.cfg_node)
        )
def identify_triggers(
    cfg,
    sources,
    sinks,
    lattice
):
    """Identify sources, sinks and sanitisers in a CFG.

    Args:
        cfg(CFG): CFG to find sources, sinks and sanitisers in.
        sources(tuple): list of sources, a source is a (source, sanitiser) tuple.
        sinks(tuple): list of sources, a sink is a (sink, sanitiser) tuple.
        lattice(Lattice): the lattice we're analysing.

    Returns:
        Triggers tuple with sink and source nodes and a sanitiser node dict.
    """
    assignment_nodes = filter_cfg_nodes(cfg, AssignmentNode)
    tainted_nodes = filter_cfg_nodes(cfg, TaintedNode)
    # TaintedNodes always count as sources (framework URL parameters).
    tainted_trigger_nodes = [TriggerNode('Framework function URL parameter', None,
                                         node) for node in tainted_nodes]
    sources_in_file = find_triggers(assignment_nodes, sources)
    sources_in_file.extend(tainted_trigger_nodes)

    find_secondary_sources(assignment_nodes, sources_in_file, lattice)

    sinks_in_file = find_triggers(cfg.nodes, sinks)

    sanitiser_node_dict = build_sanitiser_node_dict(cfg, sinks_in_file)

    return Triggers(sources_in_file, sinks_in_file, sanitiser_node_dict)
def filter_cfg_nodes(
    cfg,
    cfg_node_type
):
    """Return the nodes of `cfg` that are instances of `cfg_node_type`."""
    return [candidate for candidate in cfg.nodes
            if isinstance(candidate, cfg_node_type)]
def find_secondary_sources(
    assignment_nodes,
    sources,
    lattice
):
    """Set the secondary_nodes attribute of each source to its reassignments.

    Args:
        assignment_nodes([AssignmentNode])
        sources([tuple])
        lattice(Lattice): the lattice we're analysing.
    """
    for src in sources:
        src.secondary_nodes = find_assignments(assignment_nodes, src, lattice)
def find_assignments(
    assignment_nodes,
    source,
    lattice
):
    """Return all reassignments of `source`, propagated to a fixed point.

    Args:
        assignment_nodes([AssignmentNode]): assignment nodes of the CFG.
        source(TriggerNode): the source whose reassignments are collected.
        lattice(Lattice): the lattice we're analysing.

    Returns:
        List of AssignmentNodes reached from the source (source excluded).
    """
    old = list()
    # propagate reassignments of the source node
    new = [source.cfg_node]
    while new != old:
        # BUG FIX: `old` must be a *copy* taken before updating; the
        # previous `old = new` (inside the loop, after updating) aliased
        # both names to the same list, so the loop always terminated
        # after a single pass instead of iterating to a fixed point.
        old = list(new)
        update_assignments(new, assignment_nodes, source.cfg_node, lattice)
    # remove source node from result
    del new[0]
    return new
def update_assignments(
    assignment_list,
    assignment_nodes,
    source,
    lattice
):
    """Grow assignment_list with assignments reachable from nodes already in it.

    Note: assignment_list is mutated while the inner loop iterates it, so
    additions made during one pass are themselves considered in the same
    pass; the `node not in assignment_list` re-check prevents duplicates.
    """
    for node in assignment_nodes:
        for other in assignment_list:
            if node not in assignment_list and lattice.in_constraint(other, node):
                append_node_if_reassigned(assignment_list, other, node)
def append_node_if_reassigned(
    assignment_list,
    secondary,
    node
):
    """Append `node` when it reads or overwrites `secondary`'s LHS variable."""
    lhs = secondary.left_hand_side
    if lhs in node.right_hand_side_variables or lhs == node.left_hand_side:
        assignment_list.append(node)
def find_triggers(
    nodes,
    trigger_words
):
    """Find triggers from the trigger_word_list in the nodes.

    Args:
        nodes(list[Node]): the nodes to find triggers in.
        trigger_words(list[tuple]): (trigger word, sanitisers) pairs to look for.

    Returns:
        List of found TriggerNodes.
    """
    found = list()
    for candidate in nodes:
        # label_contains yields one TriggerNode per matching trigger word.
        found.extend(label_contains(candidate, trigger_words))
    return found
def label_contains(
    node,
    trigger_words
):
    """Yield a TriggerNode for every trigger word found in node's label.

    Args:
        node(Node): CFG node to check.
        trigger_words(list[tuple]): (trigger word, sanitisers) pairs.

    Returns:
        Iterable of TriggerNodes found. Can be multiple because multiple
        trigger_words can be in one node.
    """
    for word, sanitisers in trigger_words:
        # Plain substring match against the node's label.
        if word in node.label:
            yield TriggerNode(word, sanitisers, node)
def build_sanitiser_node_dict(
    cfg,
    sinks_in_file
):
    """Build a dict of string -> TriggerNode pairs, where the string
    is the sanitiser and the TriggerNode is a TriggerNode of the sanitiser.

    Args:
        cfg(CFG): cfg to traverse.
        sinks_in_file(list[TriggerNode]): list of TriggerNodes containing
                                          the sinks in the file.

    Returns:
        A string -> TriggerNode dict.
    """
    sanitisers = list()
    for sink in sinks_in_file:
        sanitisers.extend(sink.sanitisers)

    # Every (sanitiser, node) occurrence in the CFG, by substring match.
    sanitisers_in_file = [
        Sanitiser(sanitiser, cfg_node)
        for sanitiser in sanitisers
        for cfg_node in cfg.nodes
        if sanitiser in cfg_node.label
    ]

    return {
        sanitiser: list(find_sanitiser_nodes(sanitiser, sanitisers_in_file))
        for sanitiser in sanitisers
    }
def find_sanitiser_nodes(
    sanitiser,
    sanitisers_in_file
):
    """Find nodes containing a particular sanitiser.

    Args:
        sanitiser(string): sanitiser to look for.
        sanitisers_in_file(list[Node]): list of CFG nodes with the sanitiser.

    Returns:
        Iterable of sanitiser nodes.
    """
    for entry in sanitisers_in_file:
        if entry.trigger_word == sanitiser:
            yield entry.cfg_node
def get_sink_args(cfg_node):
    """Return the argument/variable names flowing into a sink node.

    NOTE(review): the BBorBInode case is checked *after* the ast_node
    isinstance checks, so it only triggers when cfg_node.ast_node is
    neither an ast.Call nor an ast.Assign — verify against how
    BBorBInode instances are constructed.
    """
    if isinstance(cfg_node.ast_node, ast.Call):
        rhs_visitor = RHSVisitor()
        rhs_visitor.visit(cfg_node.ast_node)
        return rhs_visitor.result
    elif isinstance(cfg_node.ast_node, ast.Assign):
        return cfg_node.right_hand_side_variables
    elif isinstance(cfg_node, BBorBInode):
        return cfg_node.args
    else:
        vv = VarsVisitor()
        vv.visit(cfg_node.ast_node)
        return vv.result
def get_vulnerability_chains(
    current_node,
    sink,
    def_use,
    chain=None
):
    """Traverses the def-use graph to find all paths from source to sink that cause a vulnerability.

    Args:
        current_node(Node): the definition node currently being visited.
        sink(Node): the sink node the paths must reach.
        def_use(dict): definition -> uses mapping from build_def_use_chain.
        chain(list(Node)): A path of nodes between source and sink.

    Yields:
        Each chain of intermediate nodes leading to the sink.
    """
    # BUG FIX: the old `chain=[]` default was a shared mutable default
    # argument; create a fresh list per top-level call instead.
    if chain is None:
        chain = []
    for use in def_use[current_node]:
        if use == sink:
            yield chain
        else:
            # Branch with a copy so sibling paths don't share the list.
            vuln_chain = list(chain)
            vuln_chain.append(use)
            yield from get_vulnerability_chains(
                use,
                sink,
                def_use,
                vuln_chain
            )
def how_vulnerable(
    chain,
    blackbox_mapping,
    sanitiser_nodes,
    potential_sanitiser,
    blackbox_assignments,
    ui_mode,
    vuln_deets
):
    """Iterates through the chain of nodes and checks the blackbox nodes against the blackbox mapping and sanitiser dictionary.

    Note: potential_sanitiser is the only hack here, it is because we do not take p-use's into account yet.
    e.g. we can only say potentially instead of definitely sanitised in the path_traversal_sanitised_2.py test.

    Args:
        chain(list(Node)): A path of nodes between source and sink.
        blackbox_mapping(dict): A map of blackbox functions containing whether or not they propagate taint.
        sanitiser_nodes(set): A set of nodes that are sanitisers for the sink.
        potential_sanitiser(Node): An if or elif node that can potentially cause sanitisation.
        blackbox_assignments(set[AssignmentNode]): set of blackbox assignments, includes the ReturnNode's of BBorBInode's.
        ui_mode(UImode): determines if we interact with the user when we don't already have a blackbox mapping available.
        vuln_deets(dict): vulnerability details.

    Returns:
        A VulnerabilityType depending on how vulnerable the chain is.
    """
    for i, current_node in enumerate(chain):
        if current_node in sanitiser_nodes:
            vuln_deets['sanitiser'] = current_node
            vuln_deets['confident'] = True
            return VulnerabilityType.SANITISED

        if isinstance(current_node, BBorBInode):
            if current_node.func_name in blackbox_mapping['propagates']:
                continue
            elif current_node.func_name in blackbox_mapping['does_not_propagate']:
                return VulnerabilityType.FALSE
            elif ui_mode == UImode.INTERACTIVE:
                # Ask the user and record the answer in the mapping, which
                # find_vulnerabilities later writes back to disk.
                user_says = input(
                    'Is the return value of {} with tainted argument "{}" vulnerable? (Y/n)'.format(
                        current_node.label,
                        chain[i - 1].left_hand_side
                    )
                ).lower()
                if user_says.startswith('n'):
                    blackbox_mapping['does_not_propagate'].append(current_node.func_name)
                    return VulnerabilityType.FALSE
                blackbox_mapping['propagates'].append(current_node.func_name)
            else:
                # Unknown blackbox call with no mapping and no user to ask.
                vuln_deets['unknown_assignment'] = current_node
                return VulnerabilityType.UNKNOWN

    if potential_sanitiser:
        vuln_deets['sanitiser'] = potential_sanitiser
        vuln_deets['confident'] = False
        return VulnerabilityType.SANITISED

    return VulnerabilityType.TRUE
def get_tainted_node_in_sink_args(
    sink_args,
    nodes_in_constaint
):
    """Return the first node (closest to the sink) whose LHS is a sink argument."""
    if not sink_args:
        return None
    # Starts with the node closest to the sink
    for candidate in nodes_in_constaint:
        if candidate.left_hand_side in sink_args:
            return candidate
    return None
def get_vulnerability(
    source,
    sink,
    triggers,
    lattice,
    cfg,
    ui_mode,
    blackbox_mapping
):
    """Get vulnerability between source and sink if it exists.

    Uses triggers to find sanitisers.

    Note: When a secondary node is in_constraint with the sink
          but not the source, the secondary is a save_N_LHS
          node made in process_function in expr_visitor.

    Args:
        source(TriggerNode): TriggerNode of the source.
        sink(TriggerNode): TriggerNode of the sink.
        triggers(Triggers): Triggers of the CFG.
        lattice(Lattice): the lattice we're analysing.
        cfg(CFG): .blackbox_assignments used in is_unknown, .nodes used in build_def_use_chain
        ui_mode(UImode): determines if we interact with the user or trim the nodes in the output, if at all.
        blackbox_mapping(dict): A map of blackbox functions containing whether or not they propagate taint.

    Returns:
        A Vulnerability if it exists, else None
    """
    nodes_in_constaint = [secondary for secondary in reversed(source.secondary_nodes)
                          if lattice.in_constraint(secondary,
                                                   sink.cfg_node)]
    nodes_in_constaint.append(source.cfg_node)

    sink_args = get_sink_args(sink.cfg_node)
    tainted_node_in_sink_arg = get_tainted_node_in_sink_args(
        sink_args,
        nodes_in_constaint
    )

    if tainted_node_in_sink_arg:
        vuln_deets = {
            'source': source.cfg_node,
            'source_trigger_word': source.trigger_word,
            'sink': sink.cfg_node,
            'sink_trigger_word': sink.trigger_word,
            'reassignment_nodes': source.secondary_nodes
        }

        sanitiser_nodes = set()
        potential_sanitiser = None
        if sink.sanitisers:
            for sanitiser in sink.sanitisers:
                for cfg_node in triggers.sanitiser_dict[sanitiser]:
                    if isinstance(cfg_node, AssignmentNode):
                        sanitiser_nodes.add(cfg_node)
                    elif isinstance(cfg_node, IfNode):
                        potential_sanitiser = cfg_node

        # BUG FIX: build_def_use_chain takes (cfg_nodes, lattice); the
        # lattice argument was previously omitted, which raised a
        # TypeError as soon as a tainted sink argument was found.
        def_use = build_def_use_chain(cfg.nodes, lattice)
        for chain in get_vulnerability_chains(
            source.cfg_node,
            sink.cfg_node,
            def_use
        ):
            vulnerability_type = how_vulnerable(
                chain,
                blackbox_mapping,
                sanitiser_nodes,
                potential_sanitiser,
                cfg.blackbox_assignments,
                ui_mode,
                vuln_deets
            )
            if vulnerability_type == VulnerabilityType.FALSE:
                continue
            if ui_mode != UImode.NORMAL:
                vuln_deets['reassignment_nodes'] = chain
            return vuln_factory(vulnerability_type)(**vuln_deets)
    return None
def find_vulnerabilities_in_cfg(
    cfg,
    definitions,
    lattice,
    ui_mode,
    blackbox_mapping,
    vulnerabilities_list
):
    """Find vulnerabilities in a cfg.

    Args:
        cfg(CFG): The CFG to find vulnerabilities in.
        definitions(trigger_definitions_parser.Definitions): Source and sink definitions.
        lattice(Lattice): the lattice we're analysing.
        ui_mode(UImode): determines if we interact with the user or trim the nodes in the output, if at all.
        blackbox_mapping(dict): A map of blackbox functions containing whether or not they propagate taint.
        vulnerabilities_list(list): That we append to when we find vulnerabilities.
    """
    triggers = identify_triggers(
        cfg,
        definitions.sources,
        definitions.sinks,
        lattice
    )
    # Check every (source, sink) pair found in the CFG.
    for sink in triggers.sinks:
        for source in triggers.sources:
            found = get_vulnerability(
                source,
                sink,
                triggers,
                lattice,
                cfg,
                ui_mode,
                blackbox_mapping
            )
            if found:
                vulnerabilities_list.append(found)
def find_vulnerabilities(
    cfg_list,
    analysis_type,
    ui_mode,
    vulnerability_files
):
    """Find vulnerabilities in a list of CFGs from a trigger_word_file.

    Args:
        cfg_list(list[CFG]): the list of CFGs to scan.
        analysis_type(AnalysisBase): analysis object used to create lattice.
        ui_mode(UImode): determines if we interact with the user or trim the nodes in the output, if at all.
        vulnerability_files(VulnerabilityFiles): contains trigger words and blackbox_mapping files

    Returns:
        A list of vulnerabilities.
    """
    vulnerabilities = list()
    definitions = parse(vulnerability_files.triggers)

    with open(vulnerability_files.blackbox_mapping) as infile:
        blackbox_mapping = json.load(infile)
    for cfg in cfg_list:
        find_vulnerabilities_in_cfg(
            cfg,
            definitions,
            Lattice(cfg.nodes, analysis_type),
            ui_mode,
            blackbox_mapping,
            vulnerabilities
        )
    # The mapping may have been extended by interactive answers in
    # how_vulnerable, so persist it back to the same file.
    with open(vulnerability_files.blackbox_mapping, 'w') as outfile:
        json.dump(blackbox_mapping, outfile, indent=4)

    return vulnerabilities
"""This module contains vulnerability types and helpers.
It is only used in vulnerabilities.py
"""
from enum import Enum
class VulnerabilityType(Enum):
    """Classification of a source-to-sink chain (see how_vulnerable)."""
    FALSE = 0      # a blackbox call in the chain does not propagate taint
    SANITISED = 1  # a sanitiser (certain or potential) was found in the chain
    TRUE = 2       # taint reaches the sink unimpeded
    UNKNOWN = 3    # chain contains a blackbox call with unknown behaviour
def vuln_factory(vulnerability_type):
    """Return the Vulnerability class matching vulnerability_type."""
    if vulnerability_type is VulnerabilityType.UNKNOWN:
        return UnknownVulnerability
    if vulnerability_type is VulnerabilityType.SANITISED:
        return SanitisedVulnerability
    # TRUE (and any other value) maps to the base class.
    return Vulnerability
def _get_reassignment_str(reassignment_nodes):
reassignments = ''
if reassignment_nodes:
reassignments += '\nReassigned in:\n\t'
reassignments += '\n\t'.join([
'File: ' + node.path + '\n' +
'\t > Line ' + str(node.line_number) + ': ' + node.label
for node in reassignment_nodes
])
return reassignments
class Vulnerability():
    """A taint flow from a source CFG node to a sink CFG node."""

    def __init__(
        self,
        source,
        source_trigger_word,
        sink,
        sink_trigger_word,
        reassignment_nodes
    ):
        """Set source and sink information.

        Args:
            source(Node): CFG node where the taint enters.
            source_trigger_word(string): word that identified the source.
            sink(Node): CFG node the taint reaches.
            sink_trigger_word(string): word that identified the sink.
            reassignment_nodes(list[Node]): nodes the taint passed through.
        """
        self.source = source
        self.source_trigger_word = source_trigger_word
        self.sink = sink
        self.sink_trigger_word = sink_trigger_word
        self.reassignment_nodes = reassignment_nodes
        self._remove_sink_from_secondary_nodes()

    def _remove_sink_from_secondary_nodes(self):
        # The sink is not a reassignment of itself; drop it if present.
        try:
            self.reassignment_nodes.remove(self.sink)
        except ValueError:  # pragma: no cover
            pass

    def __str__(self):
        """Pretty printing of a vulnerability."""
        reassigned_str = _get_reassignment_str(self.reassignment_nodes)

        return (
            'File: {}\n'
            ' > User input at line {}, trigger word "{}":\n'
            '\t {}{}\nFile: {}\n'
            ' > reaches line {}, trigger word "{}":\n'
            '\t{}'.format(
                self.source.path,
                self.source.line_number, self.source_trigger_word,
                self.source.label, reassigned_str, self.sink.path,
                self.sink.line_number, self.sink_trigger_word,
                self.sink.label
            )
        )

    def as_dict(self):
        # JSON-serialisable representation; nodes serialise via as_dict.
        return {
            'source': self.source.as_dict(),
            'source_trigger_word': self.source_trigger_word,
            'sink': self.sink.as_dict(),
            'sink_trigger_word': self.sink_trigger_word,
            'type': self.__class__.__name__,
            'reassignment_nodes': [node.as_dict() for node in self.reassignment_nodes]
        }
class SanitisedVulnerability(Vulnerability):
    """A vulnerability whose tainted flow passes through a sanitiser."""

    def __init__(
        self,
        confident,
        sanitiser,
        **kwargs
    ):
        """Record the sanitiser alongside the base vulnerability data.

        Args:
            confident(bool): True when the sanitiser definitely applies,
                False when it only potentially applies.
            sanitiser: the node performing the sanitisation.
        """
        super().__init__(**kwargs)
        self.confident = confident
        self.sanitiser = sanitiser

    def __str__(self):
        """Pretty printing of a vulnerability."""
        qualifier = '' if self.confident else 'potentially '
        return '{}\nThis vulnerability is {}sanitised by: {}'.format(
            super().__str__(),
            qualifier,
            self.sanitiser,
        )

    def as_dict(self):
        """Extend the base dict with sanitiser information."""
        result = super().as_dict()
        result.update(
            sanitiser=self.sanitiser.as_dict(),
            confident=self.confident,
        )
        return result
class UnknownVulnerability(Vulnerability):
    """A flow involving an assignment whose effect on taint is unknown."""

    def __init__(
        self,
        unknown_assignment,
        **kwargs
    ):
        """Record the assignment that makes this vulnerability unknown.

        Args:
            unknown_assignment: the node whose effect could not be determined.
        """
        super().__init__(**kwargs)
        self.unknown_assignment = unknown_assignment

    def __str__(self):
        """Pretty printing of a vulnerability."""
        return '{}\nThis vulnerability is unknown due to: {}'.format(
            super().__str__(),
            self.unknown_assignment,
        )

    def as_dict(self):
        """Extend the base dict with the unknown assignment."""
        result = super().as_dict()
        result['unknown_assignment'] = self.unknown_assignment.as_dict()
        return result
graphviz>=0.4.10
requests>=2.12
GitPython>=2.0.8

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet