nnsight.contexts

The contexts module contains logic for managing the tracing and running of models with nnsight.tracing and nnsight.envoy.

The two primary classes involved here are Tracer and Invoker.

The Tracer class creates a Graph around the underlying model of an NNsight object. The graph tracks and manages the operations performed on the inputs and outputs of said model. The Envoys of the model's modules expose .output and .input attributes which, when accessed, add to the computation graph of the tracer. To do this, the Envoys need to know about the current Tracer object, so each module's Envoy has its .tracer attribute set to the current Tracer.

The Tracer object also keeps track of the batch_size of the most recent input, the generation index for multi-iteration generation, and all of the inputs made during its context in the .batched_input attribute. Inputs added to this attribute should be in a format where each index corresponds to a batch index, which allows the model to batch all of the inputs together.

This keeps things consistent: if two inputs are in two different valid formats, both are converted to the same format and are easy to batch. For LanguageModels, regardless of whether the inputs are string prompts, pre-processed dictionaries, or input ids, the batched input consists only of input ids. On exiting the Tracer context, the Tracer object should use the information and inputs provided to it to carry out the execution of the model.
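To make the normalization concrete, here is a minimal sketch in plain Python. It is not nnsight code: VOCAB, to_input_ids, and batch_inputs are hypothetical, and a whitespace split stands in for a real tokenizer. It shows three differently formatted inputs ending up as one list of input ids per batch index.

```python
from typing import Dict, List, Union

# Hypothetical toy vocabulary; a real LanguageModel would use its tokenizer.
VOCAB = {"hello": 0, "world": 1, "foo": 2}

InputLike = Union[str, List[int], Dict[str, List[int]]]


def to_input_ids(inp: InputLike) -> List[int]:
    """Normalize one input (prompt, pre-processed dict, or ids) to input ids."""
    if isinstance(inp, str):
        return [VOCAB[token] for token in inp.split()]
    if isinstance(inp, dict):
        return list(inp["input_ids"])
    return list(inp)


def batch_inputs(*inputs: InputLike) -> List[List[int]]:
    """Each index of the result is one batch entry, all in the same format."""
    return [to_input_ids(inp) for inp in inputs]


batched = batch_inputs("hello world", {"input_ids": [2, 0]}, [1, 1])
print(batched)  # [[0, 1], [2, 0], [1, 1]]
```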

The Invoker class is what actually accepts inputs to the model/graph, and it updates its parent Tracer object with the appropriate information about the respective inputs. On entering the Invoker context with some input, the Invoker leverages the model to pre-process and prepare the input. Using the prepared inputs, it updates its Tracer object with a batched version of the input, the size of the batched input, and the current generation index. It also runs a ‘meta’ version of the input through the model’s meta_model, which updates the sizes/dtypes of all of the modules’ Envoys’ inputs and outputs based on the characteristics of the input.
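The Tracer/Invoker interplay described above can be modelled with two toy context managers. ToyTracer and ToyInvoker are hypothetical stand-ins, not nnsight's classes: lowercasing strings stands in for the model's input preparation, the meta_model pass is omitted, and "running the model" is just calling a function on the batch when the tracer context exits.

```python
from typing import Any, Callable, List, Optional


class ToyTracer:
    """Hypothetical stand-in for Tracer: collects inputs, runs the model on exit."""

    def __init__(self, model: Callable[[List[Any]], List[Any]]):
        self.model = model
        self.batched_input: List[Any] = []  # one entry per batch index
        self.batch_size = 0                 # size of the most recent input
        self._invoker: Optional["ToyInvoker"] = None
        self.output: Optional[List[Any]] = None

    def invoke(self, *inputs: Any) -> "ToyInvoker":
        if self._invoker is not None:
            raise Exception("An Invoker context is already open")
        return ToyInvoker(self, *inputs)

    def __enter__(self) -> "ToyTracer":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        if exc_type is None:
            # Execute the model once over everything the invokers batched.
            self.output = self.model(self.batched_input)


class ToyInvoker:
    """Hypothetical stand-in for Invoker: prepares inputs and updates the tracer."""

    def __init__(self, tracer: ToyTracer, *inputs: Any):
        self.tracer = tracer
        self.inputs = inputs

    def __enter__(self) -> "ToyInvoker":
        self.tracer._invoker = self
        # Stand-in for the model's input preparation: lowercase every input.
        prepared = [str(x).lower() for x in self.inputs]
        self.tracer.batched_input.extend(prepared)
        self.tracer.batch_size = len(prepared)
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.tracer._invoker = None


with ToyTracer(lambda batch: [len(s) for s in batch]) as tracer:
    with tracer.invoke("Hello", "World!"):
        pass
    with tracer.invoke("abc"):
        pass

print(tracer.batched_input)  # ['hello', 'world!', 'abc']
print(tracer.output)         # [5, 6, 3]
```

Because the model only runs in the tracer's __exit__, every invocation contributes to a single batched execution.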

nnsight comes with an extension of a Tracer, RemoteTracer, which enables both local and remote execution.

nnsight.contexts.check_for_dependencies(data: Any) → Tuple[Any, bool]

Checks whether there are any Proxies in data. If so, converts them to a Bridge Node, then to a Lock Node, so that the value of the Bridge Node can be retrieved later at execution time.

Parameters:

data (Any) – Data to check for Proxies.

Returns:

Data with Proxies replaced with Bridge/Lock Nodes, and a bool indicating whether there were any Proxies in data.

Return type:

Tuple[Any, bool]

nnsight.contexts.resolve_dependencies(data: Any) → Any

Turns any dependencies (locked Bridge Nodes) within data into their values by executing the Bridge Nodes.

Parameters:

data (Any) – Data to find and resolve dependencies within.

Returns:

Data with dependencies converted to their value.

Return type:

Any
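A minimal sketch of the check/resolve pair over nested lists, tuples, and dicts follows. ToyProxy and ToyLock are hypothetical placeholders for Proxies and locked Bridge Nodes. Unlike the documented resolve_dependencies(data), this toy version takes an extra values mapping that plays the role of executing the Bridge Node.

```python
from typing import Any, Dict, Tuple


class ToyProxy:
    """Hypothetical placeholder for a value that only exists at execution time."""

    def __init__(self, name: str):
        self.name = name


class ToyLock:
    """Hypothetical stand-in for a locked Bridge Node that refers to a proxy."""

    def __init__(self, name: str):
        self.name = name


def check_for_dependencies(data: Any) -> Tuple[Any, bool]:
    found = False

    def walk(x: Any) -> Any:
        nonlocal found
        if isinstance(x, ToyProxy):
            found = True
            return ToyLock(x.name)  # replace the proxy with a lock
        if isinstance(x, (list, tuple)):
            return type(x)(walk(v) for v in x)
        if isinstance(x, dict):
            return {k: walk(v) for k, v in x.items()}
        return x

    return walk(data), found


def resolve_dependencies(data: Any, values: Dict[str, Any]) -> Any:
    if isinstance(data, ToyLock):
        return values[data.name]  # "execute" the bridge node: look up its value
    if isinstance(data, (list, tuple)):
        return type(data)(resolve_dependencies(v, values) for v in data)
    if isinstance(data, dict):
        return {k: resolve_dependencies(v, values) for k, v in data.items()}
    return data


locked, has_deps = check_for_dependencies({"a": [ToyProxy("h0"), 3], "b": "x"})
print(has_deps)                                  # True
print(resolve_dependencies(locked, {"h0": 42}))  # {'a': [42, 3], 'b': 'x'}
```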

class nnsight.contexts.Tracer.Tracer(backend: Backend, model: NNsight, validate: bool = False, graph: Graph = None, bridge: Bridge = None, return_context: bool = False, **kwargs)

The Tracer class creates a nnsight.tracing.Graph.Graph around the ._model of a nnsight.models.NNsightModel.NNsight which tracks and manages the operations performed on the inputs and outputs of said model.

_model

nnsight Model object that this context manager traces and executes.

Type:

nnsight.models.NNsightModel.NNsight

_graph

Graph to which operations performed on the inputs and outputs of modules’ Envoys are added, and which is later executed.

Type:

nnsight.tracing.Graph.Graph

_args

Positional arguments to be passed to the function that executes the model.

Type:

List[Any]

_kwargs

Keyword arguments to be passed to the function that executes the model.

Type:

Dict[str,Any]

_invoker_inputs

Inputs for each invocation of this Tracer.

Type:

List[Any]

_invoker

Currently open Invoker.

Type:

Invoker

edit_backend_execute() → Graph

Should execute this object locally and return a result that can be handled by EditMixin objects.

Returns:

Result containing data to return from an edit execution.

Return type:

Any

invoke(*inputs: Any, **kwargs) → Invoker

Creates an Invoker context for a given input.

Raises:

Exception – If an Invoker context is already open.

Returns:

Invoker.

Return type:

Invoker

local_backend_execute() → Graph

Should execute this object locally and return a result that can be handled by RemoteMixin objects.

Returns:

Result containing data to return from a remote execution.

Return type:

Any

next(increment: int = 1) → None

Increments call_iter of all module Envoys. Useful when doing iterative/generative runs.

Parameters:

increment (int) – Amount by which to increment call_iter at once. Defaults to 1.
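The effect of next() can be sketched with a hypothetical ToyEnvoy that carries only a call_iter counter. For illustration only, this assumes call_iter selects which forward call an Envoy's .input/.output refer to during iterative runs; ToyTracer here is likewise hypothetical.

```python
from typing import List


class ToyEnvoy:
    """Hypothetical module Envoy holding only a call_iter counter."""

    def __init__(self, name: str):
        self.name = name
        self.call_iter = 0


class ToyTracer:
    """Hypothetical tracer that knows every module's Envoy."""

    def __init__(self, envoys: List[ToyEnvoy]):
        self.envoys = envoys

    def next(self, increment: int = 1) -> None:
        # Advance every Envoy together so they keep pointing at the same call.
        for envoy in self.envoys:
            envoy.call_iter += increment


envoys = [ToyEnvoy("layer0"), ToyEnvoy("lm_head")]
tracer = ToyTracer(envoys)
tracer.next()   # move to the second call
tracer.next(2)  # skip ahead two more calls
print([e.call_iter for e in envoys])  # [3, 3]
```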

remote_backend_get_model_key() → str

Should return the model_key used to specify which model to run on the remote service.

Returns:

Model key.

Return type:

str

remote_backend_handle_result_value(value: Dict[str, Any]) → None

Should handle the post-processed result from remote_backend_postprocess_result on return from the remote service.

Parameters:

value (Any) – Result.

remote_backend_postprocess_result(local_result: Graph) → Dict[str, Any]

Should handle postprocessing the result from local_backend_execute.

For example, moving tensors to CPU, detaching them, etc.

Parameters:

local_result (Any) – Local execution result.

Returns:

Post processed local execution result.

Return type:

Any

class nnsight.contexts.Invoker.Invoker(tracer: Tracer, *inputs: Any, scan: bool = False, **kwargs)

An Invoker is meant to work in tandem with a nnsight.contexts.Tracer.Tracer to enter input and manage intervention tracing.

tracer

Tracer object to enter input and manage context.

Type:

nnsight.contexts.Tracer.Tracer

inputs

Initially entered inputs, then post-processed inputs from model’s ._prepare_inputs(…) method.

Type:

tuple[Any]

scan

Whether to execute the model using FakeTensor in order to update the potential sizes/dtypes of all modules’ Envoys’ inputs/outputs, as well as to validate that things work correctly. Scanning is not computationally free, so you may want to set this to False when running in a loop. When making interventions, you may get shape errors if scan is False, because validation relies on shapes; for looped calls where shapes are consistent, you may want to set scan=True for only the first iteration. Defaults to False.

Type:

bool
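The scan-once-per-loop advice can be expressed as a small helper. run_batches and fake_run are hypothetical; run stands in for whatever opens an Invoker with scan=..., and the sketch assumes every batch has the same shape, so a single scan on the first iteration is enough.

```python
from typing import Any, Callable, List, Sequence


def run_batches(batches: Sequence[List[int]], run: Callable[..., Any]) -> List[Any]:
    """Hypothetical helper: pass scan=True on the first iteration only.

    Assumes all batches share one shape, so validating once is enough.
    """
    outputs = []
    for i, batch in enumerate(batches):
        outputs.append(run(batch, scan=(i == 0)))
    return outputs


calls: List[bool] = []


def fake_run(batch: List[int], scan: bool = False) -> int:
    calls.append(scan)  # record whether this call would have scanned
    return sum(batch)


results = run_batches([[1, 2], [3, 4], [5, 6]], fake_run)
print(results)  # [3, 7, 11]
print(calls)    # [True, False, False]
```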

kwargs

Keyword arguments passed to the model’s _prepare_inputs method.

Type:

Dict[str,Any]

scanning

Whether the Invoker is currently scanning.

Type:

bool

class nnsight.contexts.GraphBasedContext.GlobalTracingContext

The Global Tracing Context handles adding tracing operations globally without reference to a given GraphBasedContext. There should be only one instance of it: GlobalTracingContext.GLOBAL_TRACING_CONTEXT. GlobalTracingContext.TORCH_HANDLER handles adding torch functions without reference to a given GraphBasedContext.

class GlobalTracingExit

class GlobalTracingTorchHandler

static deregister() → None

Deregister GraphBasedContext globally.

Parameters:

graph_based_context (GraphBasedContext) – GraphBasedContext to deregister.

static register(graph_based_context: GraphBasedContext) → None

Register GraphBasedContext globally.

Parameters:

graph_based_context (GraphBasedContext) – GraphBasedContext to register.

static try_deregister(graph_based_context: GraphBasedContext) → bool

Attempts to deregister a Graph globally. Does nothing if graph_based_context does not have the same Graph as the currently registered one.

Parameters:

graph_based_context (GraphBasedContext) – GraphBasedContext to deregister.

Returns:

True if deregistering was successful, False otherwise.

Return type:

bool

static try_register(graph_based_context: GraphBasedContext) → bool

Attempts to register a Graph globally. Does nothing if one is already registered.

Parameters:

graph_based_context (GraphBasedContext) – GraphBasedContext to register.

Returns:

True if registering was successful, False otherwise.

Return type:

bool
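The try_register/try_deregister rules amount to a single global slot. The sketch below models that slot with hypothetical Ctx and ToyGlobalContext classes. It follows the documented behavior (refuse to register over an existing entry, refuse to deregister a context whose Graph differs), but it is not nnsight's implementation.

```python
from typing import Optional


class Ctx:
    """Hypothetical context object that owns a graph."""

    def __init__(self, graph: object):
        self.graph = graph


class ToyGlobalContext:
    """Hypothetical single global slot mirroring try_register/try_deregister."""

    _registered: Optional[Ctx] = None

    @classmethod
    def try_register(cls, ctx: Ctx) -> bool:
        if cls._registered is not None:
            return False  # one is already registered: leave it in place
        cls._registered = ctx
        return True

    @classmethod
    def try_deregister(cls, ctx: Ctx) -> bool:
        if cls._registered is None or cls._registered.graph is not ctx.graph:
            return False  # not the registered Graph: do nothing
        cls._registered = None
        return True


a, b = Ctx(object()), Ctx(object())
results = [
    ToyGlobalContext.try_register(a),    # True: slot was empty
    ToyGlobalContext.try_register(b),    # False: a is already registered
    ToyGlobalContext.try_deregister(b),  # False: b's graph is not registered
    ToyGlobalContext.try_deregister(a),  # True
    ToyGlobalContext.try_register(b),    # True: slot is free again
]
print(results)  # [True, False, False, True, True]
```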

class nnsight.contexts.GraphBasedContext.GraphBasedContext(backend: Backend, graph: Graph = None, bridge: Bridge = None, **kwargs)

apply(target: Callable, *args, validate: bool = None, **kwargs) → InterventionProxy

Helper method to directly add a function to the intervention graph.

Parameters:
  • target (Callable) – Function to apply.

  • validate (bool) – Whether to try running this operation in FakeMode to test and scan it.

Returns:

Proxy of applying that function.

Return type:

InterventionProxy
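Conceptually, apply records a call in the graph instead of running it, and the returned proxy stands for the future result. A minimal deferred-graph sketch with hypothetical ToyGraph and ToyNode classes (no validation or FakeMode) looks like this:

```python
from typing import Any, Callable, Dict, List, Tuple


class ToyNode:
    """Hypothetical proxy for the result of one recorded call."""

    def __init__(self, index: int):
        self.index = index


class ToyGraph:
    """Hypothetical deferred graph: apply() records, execute() runs in order."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[Callable[..., Any], tuple, Dict[str, Any]]] = []

    def apply(self, target: Callable[..., Any], *args: Any, **kwargs: Any) -> ToyNode:
        self.nodes.append((target, args, kwargs))  # nothing is executed yet
        return ToyNode(len(self.nodes) - 1)

    def execute(self) -> List[Any]:
        values: List[Any] = []

        def resolve(x: Any) -> Any:
            # Replace a node reference by the value it produced earlier.
            return values[x.index] if isinstance(x, ToyNode) else x

        for target, args, kwargs in self.nodes:
            call_args = [resolve(a) for a in args]
            call_kwargs = {k: resolve(v) for k, v in kwargs.items()}
            values.append(target(*call_args, **call_kwargs))
        return values


g = ToyGraph()
a = g.apply(max, 3, 7)
b = g.apply(pow, a, 2)
print(g.execute())  # [7, 49]
```

Because each node can only reference earlier nodes, executing them in insertion order is enough to resolve every dependency.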

bool(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable bool.

bridge_backend_handle(bridge: Bridge) → None

Should add self to the current Bridge in some capacity.

Parameters:

bridge (Bridge) – Current Bridge.

bytearray(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable bytearray.

bytes(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable bytes.

complex(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable complex number.

cond(condition: InterventionProxy | Any) → Conditional

Entrypoint to the Conditional context.

Takes in a condition argument which acts as the dependency of the Conditional node in the Intervention graph. The condition is evaluated as a boolean, and if True, executes all the interventions defined within the body of the conditional context.

Parameters:

condition (Union[InterventionProxy, Any]) – Dependency of the Conditional Node.

Returns:

Conditional context object.

Return type:

Conditional

Example

Setup:

Ex 1: The .save() on the model output will only be executed if the condition passed to tracer.cond() evaluates to True.

Ex 2: The condition is on an InterventionProxy, which in turn creates an InterventionProxy.
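The example code for cond did not survive in this page. As a substitute, the following hypothetical sketch (not nnsight's API) models the documented semantics: nodes added inside a cond block record the condition node they depend on, and at execution time they are skipped when that condition evaluates to False. Ref and ToyGraph are illustrative names only.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List, Optional, Tuple


class Ref:
    """Hypothetical handle to the value a node will produce."""

    def __init__(self, index: int):
        self.index = index


class ToyGraph:
    """Hypothetical graph whose nodes may depend on a deferred condition."""

    def __init__(self) -> None:
        # Each node: (function, args, index of its condition node or None).
        self.nodes: List[Tuple[Callable[..., Any], tuple, Optional[int]]] = []
        self._cond: Optional[int] = None

    def add(self, fn: Callable[..., Any], *args: Any) -> Ref:
        self.nodes.append((fn, args, self._cond))
        return Ref(len(self.nodes) - 1)

    @contextmanager
    def cond(self, condition: Ref) -> Iterator[None]:
        previous, self._cond = self._cond, condition.index
        try:
            yield
        finally:
            self._cond = previous  # restore the enclosing condition, if any

    def execute(self) -> List[Any]:
        values: List[Any] = []
        for fn, args, cond in self.nodes:
            if cond is not None and not values[cond]:
                values.append(None)  # condition was False: skip this node
                continue
            resolved = [values[a.index] if isinstance(a, Ref) else a for a in args]
            values.append(fn(*resolved))
        return values


g = ToyGraph()
x = g.add(lambda: 5)
is_even = g.add(lambda v: v % 2 == 0, x)
with g.cond(is_even):
    g.add(lambda v: f"{v} is even", x)
is_odd = g.add(lambda v: v % 2 == 1, x)
with g.cond(is_odd):
    g.add(lambda v: f"{v} is odd", x)

print(g.execute())  # [5, False, None, True, '5 is odd']
```

The condition is only a node reference while tracing; its truth value is known only when execute() reaches it.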

dict(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable dictionary.

exit() → InterventionProxy

Exits the execution of a sequential intervention graph.

Returns:

Proxy of the EarlyStopProtocol node.

Return type:

InterventionProxy
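Stopping a sequential graph early can be sketched with an exception that unwinds the execution loop. EarlyStop and run_sequential are hypothetical names, and this page does not state whether nnsight's EarlyStopProtocol uses this exact mechanism.

```python
from typing import Any, Callable, List


class EarlyStop(Exception):
    """Hypothetical signal that ends the toy execution loop early."""


def run_sequential(steps: List[Callable[[], Any]]) -> List[Any]:
    results: List[Any] = []
    try:
        for step in steps:
            results.append(step())
    except EarlyStop:
        pass  # the remaining steps are never run
    return results


def stop() -> Any:
    raise EarlyStop()


print(run_sequential([lambda: "layer0", lambda: "layer1", stop, lambda: "layer2"]))
# ['layer0', 'layer1']
```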

float(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable float.

int(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable int.

list(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable list.

log(*data: Any) → None

Adds a node via .apply to print the value of a Node.

Parameters:

data (Any) – Data to print.

set(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable set.

str(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable string.

tuple(*args, **kwargs) → InterventionProxy

NNsight helper method to create a traceable tuple.

vis(**kwargs) → None

Helper method to save a visualization of the current state of the intervention graph.

class nnsight.contexts.Conditional.Conditional(graph: Graph, condition: 'InterventionProxy' | Any)

A context defined by a boolean condition, upon which the execution of all nodes defined from within is contingent.

_graph

Conditional Context graph.

Type:

Graph

_condition

Condition.

Type:

Union[InterventionProxy, Any]

class nnsight.contexts.Conditional.ConditionalManager

A Graph attachment that manages the Conditional contexts defined within an Intervention Graph.

_conditional_dict

Mapping of ConditionalProtocol node name to Conditional context.

Type:

Dict[str, Node]

_conditioned_nodes_dict

Mapping of ConditionalProtocol node name to all the Nodes conditioned by it.

Type:

Dict[str, Set[Node]]

_conditional_stack

Stack of visited Conditional contexts’ ConditionalProtocol nodes.

Type:

Dict

add_conditioned_node(node: Node) → None

Adds a Node to the set of nodes conditioned by the current Conditional context.

Parameters:

node (Node) – A node conditioned by the latest Conditional context.

get(key: str) → Conditional

Returns a ConditionalProtocol node.

Parameters:

key (str) – ConditionalProtocol node name.

Returns:

ConditionalProtocol node.

Return type:

Node

is_node_conditioned(node: Node) → bool

Returns True if the Node argument is conditioned by the current Conditional context.

Parameters:

node (Node) – Node.

Returns:

Whether the Node is conditioned.

Return type:

bool

peek() → 'Node' | None

Gets the current Conditional context’s ConditionalProtocol node.

Returns:

Latest ConditionalProtocol node if the ConditionalManager stack is non-empty.

Return type:

Optional[Node]

pop() → None

Pops the ConditionalProtocol node of the current Conditional context from the ConditionalManager stack.

push(conditional_node: Node) → None

Adds the Conditional to the stack of Conditional contexts.

Parameters:

conditional_node (Node) – ConditionalProtocol node.
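The push/pop/peek bookkeeping above can be mirrored with a plain list as the stack and string node names. ToyConditionalManager is hypothetical: the documented _conditional_stack has type Dict, while this sketch uses a list for simplicity, and get() and _conditional_dict are omitted.

```python
from typing import Dict, List, Optional, Set


class ToyConditionalManager:
    """Hypothetical mirror of ConditionalManager using string node names."""

    def __init__(self) -> None:
        self._conditioned_nodes_dict: Dict[str, Set[str]] = {}
        self._conditional_stack: List[str] = []  # innermost context last

    def push(self, conditional_node: str) -> None:
        self._conditional_stack.append(conditional_node)
        self._conditioned_nodes_dict.setdefault(conditional_node, set())

    def pop(self) -> None:
        self._conditional_stack.pop()

    def peek(self) -> Optional[str]:
        return self._conditional_stack[-1] if self._conditional_stack else None

    def add_conditioned_node(self, node: str) -> None:
        current = self.peek()
        if current is None:
            raise RuntimeError("no open Conditional context")
        self._conditioned_nodes_dict[current].add(node)

    def is_node_conditioned(self, node: str) -> bool:
        current = self.peek()
        return current is not None and node in self._conditioned_nodes_dict[current]


m = ToyConditionalManager()
m.push("conditional_0")
m.add_conditioned_node("add_1")
m.push("conditional_1")  # a nested Conditional context
print(m.peek(), m.is_node_conditioned("add_1"))  # conditional_1 False
m.pop()
print(m.peek(), m.is_node_conditioned("add_1"))  # conditional_0 True
```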

class nnsight.contexts.session.Session.Session(backend: Backend, model: NNsight, *args, bridge: Bridge = None, **kwargs)

A Session is a root Collection that handles adding new Graphs and new Collections while in the session.

bridge

Bridge object which stores all Graphs added during the session and handles interaction between them.

Type:

Bridge

graph

Root Graph where operations and values meant for access by all subsequent Graphs should be stored and referenced.

Type:

Graph

model

NNsight model.

Type:

NNsight

backend

Backend for this context object.

Type:

Backend

iter(iterable: Iterable, **kwargs) → Iterator

Creates an Iterator context to iteratively execute an intervention graph, with an updated item at each iteration.

Parameters:
  • iterable (Iterable) – Data to iterate over.

  • return_context (bool) – If True, returns the Iterator context. Default: False.

Returns:

Iterator context.

Return type:

Iterator

Example

Setup:

Ex:

local_backend_execute() → Dict[int, Graph]

Should execute this object locally and return a result that can be handled by RemoteMixin objects.

Returns:

Result containing data to return from a remote execution.

Return type:

Any

remote_backend_get_model_key() → str

Should return the model_key used to specify which model to run on the remote service.

Returns:

Model key.

Return type:

str

remote_backend_handle_result_value(value: Dict[int, Dict[str, Any]])

Should handle the post-processed result from remote_backend_postprocess_result on return from the remote service.

Parameters:

value (Any) – Result.

remote_backend_postprocess_result(local_result: Dict[int, Graph])

Should handle postprocessing the result from local_backend_execute.

For example, moving tensors to CPU, detaching them, etc.

Parameters:

local_result (Any) – Local execution result.

Returns:

Post processed local execution result.

Return type:

Any

class nnsight.contexts.session.Iterator.Iterator(data: Iterable, *args, return_context: bool = False, **kwargs)

Intervention loop context for iterative execution of an intervention graph.

data

Data to iterate over.

Type:

Iterable

return_context

If True, returns the Iterator object upon entering the Iterator context.

Type:

bool
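Iterative execution can be sketched as recording a loop body once and running it for every item when the context exits. ToyIterator is hypothetical and ignores return_context, remote execution, and Bridge handling.

```python
from typing import Any, Callable, Iterable, List


class ToyIterator:
    """Hypothetical loop context: record a body once, execute it per item on exit."""

    def __init__(self, data: Iterable[Any]):
        self.data = data
        self.body: List[Callable[[Any], Any]] = []
        self.results: List[List[Any]] = []

    def __enter__(self) -> "ToyIterator":
        return self

    def add(self, fn: Callable[[Any], Any]) -> None:
        self.body.append(fn)  # traced once, not executed yet

    def __exit__(self, exc_type, exc, tb) -> None:
        if exc_type is None:
            # Re-run the same recorded body with an updated item each time.
            for item in self.data:
                self.results.append([fn(item) for fn in self.body])


with ToyIterator([1, 2, 3]) as loop:
    loop.add(lambda item: item * 10)
    loop.add(lambda item: item % 2 == 0)

print(loop.results)  # [[10, False], [20, True], [30, False]]
```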