::: {#345402d9 .cell 0=‘h’ 1=‘i’ 2=‘d’ 3=‘e’}

from nbdev.showdoc import *

:::

::: {#d637dc48 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

import re
from typing import List,Callable,Tuple,Union,TypedDict
import torch
from torch import nn
import torch.nn.functional as F
from torchtyping import TensorType
from einops import rearrange
from toolformer.api import BaseApI

:::

::: {#20796764 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class AugmentedCandidate(TypedDict):
    """Typed dictionary with a single ``api_start_positions`` entry."""
    # Declared as an ``int``.
    # NOTE(review): ``DataGenerator.filter_api`` builds a dict under a key of the
    # same name that maps each start position to a nested dict -- confirm whether
    # this type is meant to describe that structure and, if so, widen the annotation.
    api_start_positions:int

:::

::: {#e3c6e019 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class DataGenerator(nn.Module):
    def __init__(
        self,
        config: dict,
        model: Callable, tokenizer: Callable,
        apis: List[BaseApI],
        #device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device=torch.device("cpu")
    ):
        super().__init__()
        start_character = config["data_generator"]["api_start_ch"]
        end_character = config["data_generator"]["api_end_ch"]
        output_character = config["data_generator"]["api_out_ch"]
        
        # add a space, because when the model generate a token, it's also include a "space"
        self.api_start_token_id = tokenizer(f' {start_character}', return_tensors="pt")["input_ids"][0]
        self.api_end_token_id = tokenizer(end_character, return_tensors="pt")["input_ids"][0]
        self.api_output_token_id = tokenizer(f'{output_character}', return_tensors="pt")["input_ids"][0]
        
        self.top_k_sampling = config["data_generator"]["top_k_sampling"]
        self.sampling_threshold = config["data_generator"]["sampling_threshold"]
        self.filtering_threshold = config["data_generator"]["filtering_threshold"]
        
        self.apis = apis
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        
        # TODO: handle for cases that the sentence contains ".\n\n"
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer(".\n\n")["input_ids"][0]
    
    def sample_api_position(
        self,
        prompt_ids: TensorType["seq_len"], # the ids of the prompt
    ) -> Tuple[
        TensorType["n_positions"], # The positions of api call
        TensorType["seq_len"] # The generated text
    ]:
        """Sampling API positions."""
        # TODO: add support batch
        # the ids of the prompt and generated_ids
        prompt_and_generated_ids = prompt_ids
        # only the ids of the generated_ids
        generated_ids = torch.tensor([]).to(self.device)
        i = torch.tensor([0]).to(self.device)
        
        api_pos_probs = torch.tensor([])
        
        with torch.no_grad():    
            while True:
                logits = self.model(
                    input_ids=prompt_and_generated_ids.unsqueeze(0),
                ).logits

                last_logit = logits[0, -1, :]
                probs = torch.softmax(last_logit, dim=-1)
                api_start_prob = probs[self.api_start_token_id]
                
                if api_start_prob > self.sampling_threshold:
                    api_pos_probs = torch.cat([
                        api_pos_probs,
                        torch.tensor([api_start_prob, i]).unsqueeze(0)
                    ], dim=0)     
                
                # sampling a token
                # next_token = torch.multinomial(probs, num_samples=1)
                next_token = torch.argmax(probs, dim=-1)
                next_token = next_token.unsqueeze(0)
                
                prompt_and_generated_ids = torch.cat([prompt_and_generated_ids, next_token], dim=0)
                generated_ids = torch.cat([generated_ids, next_token], dim=0)
                
                if next_token == self.eos_token_id:
                    break
                else:
                    i += 1
        
        if api_pos_probs.numel() == 0:
            api_positions = torch.tensor([])
        else:
            _, indices = torch.sort(api_pos_probs[:, 0], descending=True)
            top_k_sampling = self.top_k_sampling
            api_positions = api_pos_probs[indices[:top_k_sampling], 1]
                    
        return api_positions.long(), generated_ids.long()

    def obtain_api_response(
        self,
        prompt_ids: TensorType["seq_len"],
        positions: TensorType["n_positions"],
        generated_ids: TensorType["seq_len"]
    ) -> TensorType["n_positions", "seq_len"]:
        
        MAX_PAD = 50
        
        # the ids before the start of an api call
        pre_api_ids = torch.tensor([])

        for position in positions:
            text_ids = torch.cat([generated_ids[:position], self.api_start_token_id], dim=0)
            padded_text_ids = F.pad(text_ids, pad=(MAX_PAD - text_ids.shape[-1], 0), value=self.pad_token_id)
            
            pre_api_ids = torch.cat([
                pre_api_ids,
                rearrange(padded_text_ids, "... -> 1 ...")
            ])
        
        PROMPT_LENGTH = len(prompt_ids)
        
        # TODO: optimzie this
        prompt_and_pre_api_ids = torch.tensor([])
        for x in pre_api_ids:
            prompt_and_pre_api_ids = torch.cat([
                prompt_and_pre_api_ids,
                torch.cat([prompt_ids, x]).unsqueeze(0)
            ], dim=0)
                     
        with torch.no_grad():
            candidate_ids = self.model.generate(
                input_ids=prompt_and_pre_api_ids.long(),
                eos_token_id=self.eos_token_id,
                max_new_tokens=50,
            )
        
        # filter out the prompt template
        # only keep the generated ids
        candidate_ids = candidate_ids[:, PROMPT_LENGTH:]
        
        return candidate_ids
    
    def _generate_conditioning_prompts(
        self,
        api: BaseApI,
        candidate_ids: TensorType["n_candidates", "seq_len"],
    ):
        conditioning_api_ids = torch.tensor([])

        API_NAME = api.name
        MAX_PAD = 100
        
        def extract_api_request_content(text: str, api_name: str) -> str:
            """Extract the content of an API request from a given text."""
            start_tag = f"{api_name}("
            end_tag = ")"
            start_idx = text.find(start_tag)
            if start_idx == -1:
                return None
            start_idx += len(start_tag)
            end_idx = text.find(end_tag, start_idx)
            if end_idx == -1:
                return None
            return text[start_idx:end_idx]
        
        def extract_api_syntax(text: str, api_name: str) -> str:
            """Extract the API Syntax from a given text."""
            pattern = r"\[{}\(.*?\)\]".format(api_name)
            matches = re.findall(pattern, text)
            return matches

        for text_ids in candidate_ids:
            # the ids of the prediction
            text = self.tokenizer.decode(text_ids, skip_special_tokens=True)
            
            api_request_content = extract_api_request_content(text, api_name=API_NAME)
            api_response = api(api_request_content)
            api_response_ids = self.tokenizer(api_response, return_tensors="pt")["input_ids"][0]
            # Format: "-> [api_response]"
            api_response_with_arrow_ids = torch.cat([self.api_output_token_id, api_response_ids], dim=0)
            
            api_syntax = extract_api_syntax(text, api_name=API_NAME)
            api_syntax_ids = self.tokenizer(api_syntax, return_tensors="pt")["input_ids"][0]
            api_syntax_with_response_ids = torch.cat([api_syntax_ids[:-1], api_response_with_arrow_ids, api_syntax_ids[-1:]])
            api_syntax_without_response_ids = torch.cat([api_syntax_ids[:-1], self.api_output_token_id, api_syntax_ids[-1:]])
                              
            padded_api_without_response = rearrange(
                F.pad(api_syntax_without_response_ids, pad=((MAX_PAD - api_syntax_without_response_ids.shape[-1]), 0), value=self.pad_token_id),
                "... -> 1 ..."
            )
            padded_api_with_response = rearrange(
                F.pad(api_syntax_with_response_ids, pad=((MAX_PAD - api_syntax_with_response_ids.shape[-1]), 0), value=self.pad_token_id),
                "... -> 1 ..."
            )
        
            padded_api_call = torch.cat([
                padded_api_without_response,
                padded_api_with_response
            ], dim=0)
            padded_api_call = rearrange(padded_api_call, "... -> 1 ...")
            
            conditioning_api_ids = torch.cat([conditioning_api_ids, padded_api_call], dim=0).long()
                    
        return conditioning_api_ids

    def _filter_candidate_by_threshold(
        self,
        losses,
        candidates: TensorType["seq_len"]
    ):
        filtered_augmented_text_ids = torch.tensor([])
        for i, position in enumerate(losses):
            negative_loss = min(losses[position][0], losses[position][1])
            positive_loss = losses[position][2]
            
            if negative_loss - positive_loss >= self.filtering_threshold:
                # filtered_augmented_text_ids.append(candidates[i])
                filtered_augmented_text_ids = torch.cat([
                    filtered_augmented_text_ids,
                    candidates[i].unsqueeze(0)
                ], dim=0)
        
        return filtered_augmented_text_ids.long()

    def filter_api( 
        self,
        api: BaseApI,
        text_ids: TensorType["seq_len"],
        api_start_idxs: TensorType["n_positions"],
        candidate_ids: TensorType["n_positions", "seq_len"]
    ):
        conditioning_api_ids = self._generate_conditioning_prompts(api, candidate_ids)
                
        SPACE_TOKEN = self.tokenizer(". ", return_tensors="pt")["input_ids"][0]
        API_LENGTH = 100
        augmented_text_ids = {"api_start_positions": {}}
        
        def _compute_weight(t: int) -> Union[int, float]:
            """Compute the weight in the loss function."""
            return max(0, 1-0.2*t)
        
        for idx, api_ids in zip(api_start_idxs, conditioning_api_ids):
            idx = idx.item()
            seq_len = len(text_ids)
            augmented_text_ids["api_start_positions"][idx] = {
                "seq_positions": {}
            }

            j = idx
            while j <= seq_len - 1:
                # if the model predic
                if j == 1:
                    j += 1
                    continue
                
                # in the formua, from x_1 to x_j (include x_j)
                # => generate_ids[:j]
                conditioning_text_ids = text_ids[:j]
                api_and_text_ids = torch.stack([
                    F.pad(conditioning_text_ids, pad=(API_LENGTH + len(SPACE_TOKEN), 0), value=self.pad_token_id), # [text_ids]
                    torch.cat([api_ids[0], SPACE_TOKEN, conditioning_text_ids], dim=0), # [api->, text_ids]
                    torch.cat([api_ids[1], SPACE_TOKEN, conditioning_text_ids], dim=0), # [api->result, text_ids]
                ], dim=0)
                                
                # the next token after x_j
                next_token_ids = text_ids[j]
                augmented_text_ids["api_start_positions"][idx]["seq_positions"][j] = {
                    "prompt_ids": api_and_text_ids,
                    "unnormalized_weight": _compute_weight(t=j-idx),
                    "losses": [],
                    "target_ids": torch.tensor([next_token_ids, next_token_ids, next_token_ids])
                }
                j += 1
        
        def _normalize_weights(augmented_text_ids):
            """Normalize the weight of each position in a sequence."""
            for api_start_position in augmented_text_ids["api_start_positions"].values():
                total_weight = sum([seq_position["unnormalized_weight"] for seq_position in api_start_position["seq_positions"].values()])
                for seq_position in api_start_position["seq_positions"].values():
                    seq_position["normalized_weight"] = seq_position["unnormalized_weight"] / total_weight
            
            return augmented_text_ids
        
        augmented_text_ids = _normalize_weights(augmented_text_ids)
                
        def extract_conditioning_ids_and_target_ids(augmented_text_ids):
            conditioning_text_ids = torch.tensor([])
            target_ids = torch.tensor([])
            
            for _, api_start_position_dict in augmented_text_ids["api_start_positions"].items():
                for _, seq_position_dict in api_start_position_dict["seq_positions"].items():
                    target_ids = torch.concat([target_ids, seq_position_dict["target_ids"]], dim=0)
                    for prompt_id in seq_position_dict["prompt_ids"]:
                        conditioning_text_ids = torch.cat([
                            conditioning_text_ids,
                            F.pad(prompt_id.long(), pad=(50-prompt_id.shape[-1], 0), value=self.pad_token_id).unsqueeze(0)
                        ], dim=0)
        
            return conditioning_text_ids.long(), target_ids.long()

        conditioning_text_ids, target_ids = extract_conditioning_ids_and_target_ids(augmented_text_ids)
            
        output = self.model(input_ids=conditioning_text_ids.long())
        logits = output.logits[:, -1, :]
                    
        def extract_target_logprob_from_logits(logits, target_ids):
            log_probs = F.log_softmax(logits, dim=-1)
            target_log_probs = log_probs[range(target_ids.shape[-1]), target_ids]
            return target_log_probs

        log_probs = extract_target_logprob_from_logits(logits, target_ids)
            
        for _, api_start_position_dict in augmented_text_ids["api_start_positions"].items():
            for _, seq_position_dict in api_start_position_dict["seq_positions"].items():
                seq_position_dict["losses"] = log_probs[:3].squeeze(0)
                log_probs = log_probs[3:]
        
        def _calculate_weighted_loss(augmented_text_ids):
            for position in augmented_text_ids["api_start_positions"]:        
                seq_positions = augmented_text_ids["api_start_positions"][position]["seq_positions"]
                for i in seq_positions:
                    losses = seq_positions[i]["losses"]
                    weights = seq_positions[i]["normalized_weight"]
                    seq_positions[i]["weighted_losses"] = -losses * weights
            
            return augmented_text_ids
        
        augmented_text_ids = _calculate_weighted_loss(augmented_text_ids)
        
        def _calculate_loss(augmented_text_ids):
            data = {}
            for position in augmented_text_ids["api_start_positions"]:        
                seq_positions = augmented_text_ids["api_start_positions"][position]["seq_positions"]
                losses = [0, 0, 0]            
                for i in seq_positions:
                    losses[0] += seq_positions[i]["weighted_losses"][0] # loss for [text]
                    losses[1] += seq_positions[i]["weighted_losses"][1] # loss for [api->, text]
                    losses[2] += seq_positions[i]["weighted_losses"][2] # loss for [api-result, text]
                data[position] = losses
                
            return data
        
        losses = _calculate_loss(augmented_text_ids)
        filtered_candidate_ids = self._filter_candidate_by_threshold(losses, candidate_ids)
        return filtered_candidate_ids
    
    def generate(
        self,
        text: str,
    ) -> TensorType["n_apis", "n_candidates", "seq_len"]:
        filtered_apis = torch.tensor([])
        
        for api in self.apis:
            # TODO: add support batch
            prompt = api.prompt_template.format(input=text)
            prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
        
            # sampling positions
            api_start_idxs, generated_ids = self.sample_api_position(prompt_ids)
            
            # obtaining api responses
            candidate_ids = self.obtain_api_response(prompt_ids, api_start_idxs, generated_ids)

            # filtering
            text_ids = self.tokenizer(text, return_tensors="pt")["input_ids"][0]
            
            # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids
            filtered_candidate_ids = self.filter_api(api, text_ids, api_start_idxs, candidate_ids)
                    
            filtered_apis = torch.cat([filtered_apis, filtered_candidate_ids.unsqueeze(0)], dim=0)
        
        return filtered_apis.long()

:::