::: {#1bd9f3f1 .cell 0=‘h’ 1=‘i’ 2=‘d’ 3=‘e’}
from nbdev.showdoc import *:::
::: {#35c24500 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
from typing import Optional, List
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchtyping import TensorType
from einops import rearrange
from toolformer.api import BaseApI
from toolformer.utils import extract_api_content, extract_api_name:::
::: {#34a1bc38 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class ToolFormer(nn.Module):
    """Wrap a causal LM so that greedy generation can call external APIs inline.

    While decoding, if the model's top-1 prediction matches the API start
    token, the wrapper enters "calling" mode and records the following tokens
    as the API request. When the API end token is predicted, the recorded
    request is decoded, dispatched to the entry of ``apis`` whose ``name``
    matches, and the API's tokenized output is spliced into the sequence.

    Only a batch size of 1 is supported (see the TODOs).
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        apis: List[BaseApI],
        config: dict
    ):
        """
        Args:
            model: causal LM, called with ``input_ids``/``attention_mask`` (plus
                any extra ``forward`` kwargs); its output must expose ``.logits``.
            apis: callables matched against the extracted API name via ``.name``.
            config: reads ``config["tokenizer"]["path"]`` and the
                ``api_start_character`` / ``api_end_character`` /
                ``api_output_character`` keys of ``config["data_generator"]``.
        """
        super().__init__()
        self.model = model
        self.apis = apis
        self.config = config
        # TODO: make a config class contains token_id
        tokenizer = AutoTokenizer.from_pretrained(self.config["tokenizer"]["path"])
        # Used by ``execute_api`` to decode requests and encode API outputs,
        # so it must be kept.
        self.tokenizer = tokenizer
        start_character = config["data_generator"]["api_start_character"]
        end_character = config["data_generator"]["api_end_character"]
        output_character = config["data_generator"]["api_output_character"]
        # 1-D tensors of token ids; a marker may encode to more than one id
        # depending on the tokenizer. They are plain attributes (not buffers),
        # so they stay where they were created; ``forward`` moves them to the
        # input's device.
        self.api_start_token_id = tokenizer(f' {start_character}', return_tensors="pt")["input_ids"][0]
        self.api_end_token_id = tokenizer(end_character, return_tensors="pt")["input_ids"][0]
        self.api_output_token_id = tokenizer(f'{output_character}', return_tensors="pt")["input_ids"][0]
        # NOTE(review): encoding two strings into one tensor without padding
        # only works if both yield the same number of tokens — confirm for
        # the configured tokenizer.
        self.eos_token_ids = tokenizer(
            [".", ".\n\n"],
            return_tensors="pt"
        )["input_ids"].squeeze()
        self._reset_api_state()

    def _reset_api_state(self) -> None:
        """Leave API-calling mode and discard any recorded request tokens."""
        self.is_calling_api: bool = False
        # TODO: support batch
        self.api_request_content: torch.Tensor = torch.tensor([], dtype=torch.long)

    def _sampling(self, probs: TensorType["batch_size", "vocab_size"]) -> TensorType["batch_size"]:
        """Greedy decoding: return the index of the most probable token per row."""
        return torch.argmax(probs, dim=-1)

    def execute_api(self, text_ids: TensorType["seq_len"]) -> Optional[TensorType["seq_len"]]:
        """Execute an API call.

        Decodes ``text_ids``, extracts the API name and its arguments with the
        ``toolformer.utils`` helpers, and calls the first entry of
        ``self.apis`` whose ``name`` matches.

        Returns:
            The token ids of the API's output, or ``None`` when no API name
            could be extracted or no registered API has that name.
        """
        text = self.tokenizer.decode(text_ids, skip_special_tokens=True)
        api_name = extract_api_name(text, is_end_token=False)
        if api_name is None:
            return None
        api = next((candidate for candidate in self.apis if candidate.name == api_name), None)
        if api is None:
            return None
        api_content = extract_api_content(text, api_name=api_name)
        api_output = api(api_content)
        return self.tokenizer(api_output, return_tensors="pt")["input_ids"][0]

    def add_idx_to_api_request_content(self, idx: TensorType[1]):
        """Append token id(s) ``idx`` to the recorded API request.

        ``idx`` gets a leading axis of size 1 and is concatenated along the
        last axis, so for a 1-D ``idx`` the stored request has shape
        ``(1, n_tokens)``.
        """
        # Keep the request on its own device (CPU by default) so that
        # ``execute_api`` can decode it wherever the model runs.
        idx = idx.to(self.api_request_content.device)
        self.api_request_content = torch.cat([
            self.api_request_content,
            rearrange(idx, '... -> 1 ...')
        ], dim=-1).long()

    def forward(
        self,
        input_ids: TensorType["batch_size", "seq_len"],
        attention_mask: Optional[TensorType["batch_size", "seq_len"]]=None,
        max_new_tokens: int = 10,
        **kwargs
    ) -> TensorType["batch_size", "seq_len"]:
        """Greedily generate up to ``max_new_tokens`` steps, executing API calls inline.

        ``is_calling_api`` and any partially recorded request carry over to
        the next call; a completed request is cleared once it is executed.

        Args:
            input_ids: prompt token ids; only batch size 1 is supported.
            attention_mask: mask for ``input_ids``; defaults to all ones.
            max_new_tokens: maximum number of decoding steps. A step that
                completes an API call may append several tokens.
            **kwargs: forwarded to ``self.model`` at every step.

        Returns:
            ``input_ids`` followed by the generated token ids.
        """
        # TODO: check padding to the left
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        device = input_ids.device
        # The marker tensors are created on CPU; align them with the inputs.
        api_start_ids = self.api_start_token_id.to(device)
        api_end_ids = self.api_end_token_id.to(device)
        api_output_marker_ids = self.api_output_token_id.to(device)
        eos_ids = self.eos_token_ids.to(device)

        generated_ids = input_ids
        for _ in range(max_new_tokens):
            outputs = self.model(
                input_ids=generated_ids,
                attention_mask=attention_mask,
                **kwargs
            )
            logits = outputs.logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # TODO: k should be a config
            _, top_k_idx = torch.topk(probs, k=1, dim=-1)

            if self.is_calling_api:
                if api_end_ids in top_k_idx:
                    # The request is complete: execute it and splice
                    # "<output marker> <api output> <end marker>" into the text.
                    # TODO: add support batch
                    api_output_ids = self.execute_api(self.api_request_content[0])
                    if api_output_ids is not None:
                        pred_ids = torch.cat([
                            api_output_marker_ids,
                            api_output_ids.to(device),
                            api_end_ids
                        ], dim=-1).long()
                    else:
                        # No matching API: just close the call.
                        pred_ids = api_end_ids
                    # Clear the executed request so the next call starts fresh
                    # instead of decoding this request again.
                    self._reset_api_state()
                else:
                    pred_ids = self._sampling(probs)
                    self.add_idx_to_api_request_content(pred_ids)
            elif api_start_ids in top_k_idx:
                # The model wants to call an API: start recording the request.
                self.is_calling_api = True
                pred_ids = api_start_ids
                self.add_idx_to_api_request_content(pred_ids)
            else:
                pred_ids = self._sampling(probs)

            generated_ids = torch.cat([
                generated_ids,
                rearrange(pred_ids, '... -> 1 ...')
            ], dim=1)
            attention_mask = torch.cat([
                attention_mask,
                rearrange(torch.ones_like(pred_ids), '... -> 1 ...')
            ], dim=1)
            # ignore the case that pred_ids contains api_output
            if len(pred_ids) == 1 and pred_ids in eos_ids:
                break
        return generated_ids