# Source code for code_genie.genie

import random
import re
from typing import Any, Callable, Dict, List, Optional, Union

import pandas as pd
from pydantic import BaseModel

from code_genie._cache import _CacheManager, _CacheValue
from code_genie.client import Client


class GenieResult(BaseModel):
    """Outcome of a single genie execution."""

    id: str
    """Identifier of the genie; it doubles as the filename under which the generated code is stored in the cache."""

    code: str
    """Source code produced by the genie."""

    cache_dir: str
    """Cache directory the genie was using."""

    result: Any = None
    """Value obtained by executing the code; None when nothing was returned."""

    class Config:
        # the execution result is never part of the json export
        fields = {"result": {"exclude": True}}
        frozen = True
class Genie:
    """Generate, cache and execute code for text instructions applied to a base dataset."""

    _hash_sep = "::"
    # Name of the first function defined in a code segment. Only the identifier between
    # "def" and the opening parenthesis is captured, so default values containing
    # parentheses and signatures spread over several lines are handled correctly.
    _fn_name_pattern = re.compile(r"def\s+(\w+)\s*\(")

    def __init__(
        self,
        data: Optional[Any] = None,
        client: Optional["Client"] = None,
        cache_dir: Optional[str] = None,
        copy_data_before_use: bool = True,
    ):
        """Initialize a genie instance

        Args:
            data: a base dataset whose attributes will be used to generate the code. the result will be
                determined by running this data over the code
            client: an instance of the client to use for making requests to the api. if not provided, a new
                instance will be created.
            cache_dir: if provided, the code generated by the genie will be cached in this directory. if not
                provided, the global default is used. it is recommended to use set_cache_dir() method to set this.
            copy_data_before_use: if True, the data will be copied before passing through generated code. this is
                to prevent the data from being modified inplace by the code. the data passed should have a copy()
                method implemented. if False, the data will be passed as is. this is faster but can lead to
                unexpected results

        Raises:
            ValueError: if copy_data_before_use is True and data has no copy attribute.
        """
        # validate first so that an invalid configuration fails before the cache manager
        # or a client is created
        if copy_data_before_use and not hasattr(data, "copy"):
            raise ValueError(
                "data should have a copy method implemented if copy_data_before_use is True. "
                "Set it to False if you want to continue using the genie"
            )
        self.data = data
        self._base_key = self._get_data_key(data)
        self._cache = _CacheManager(cache_dir)
        self.copy_data_before_use = copy_data_before_use
        # compare against None so that a falsy but valid client instance is not replaced
        self._client = client if client is not None else Client()

    def plz(
        self,
        instructions: Optional[Union[str, List[str]]],
        additional_inputs: Optional[Dict[str, Any]] = None,
        override: bool = False,
        update_base_input: bool = False,
    ) -> "GenieResult":
        """Generate code for a new task

        Args:
            instructions: text instructions on the task required to be performed. use the keywords in inputs
                argument to refer to the inputs.
            additional_inputs: a dictionary of inputs to the function. the keys are the names of the inputs and
                the values are small description of the inputs.
            override: if a genie has been generated before with the same args, then it will be loaded from cache
                by default. set override to True to make a new API call and recreate the genie.
            update_base_input: if True, the base data will be replaced by the result of executing the code. this
                is used if we are making a permanent update to the input and want to use the updated input moving
                forward.

        Returns:
            A GenieResult instance which contains attributes:
                - result: the result of executing the code
                - id: the id of the genie
                - code: the code generated by the genie
                - cache_dir: the directory where the code is cached. the code will be cached in a file named
                  "cache_dir/<id>.py"
        """
        if isinstance(instructions, str):
            instructions = [instructions]

        # check cache
        cache_key = self._get_hash_str(instructions, additional_inputs)
        cache_value = self._cache.get(cache_key)

        if (not override) and (cache_value is not None):
            # case: reading from cache
            code, genie_id = cache_value.code, cache_value.id
            print(f"Loading cached genie id: {genie_id}, set override = True to rerun")
        else:
            # case: creating new genie
            inputs = self._combine_inputs(additional_inputs)
            code = self._get_code(instructions, inputs)
            genie_id = self._generate_id(code)
            self._update_cache(code, cache_key, instructions, inputs, genie_id)

        return self._get_result(code, additional_inputs, update_base_input, genie_id)

    def _update_cache(
        self,
        code: str,
        cache_key: str,
        instructions: Optional[Union[str, List[str]]],
        inputs: Dict[str, Any],
        id: str,
    ):
        """Store the code under cache_key together with its id, instructions and input names."""
        self._cache.update(
            cache_key,
            _CacheValue(code=code, id=id, instructions=instructions, inputs=list(inputs.keys())),
        )
        print(f"Genie cached with id: {id}")

    def _get_result(
        self,
        code: str,
        additional_inputs: Optional[Dict[str, Any]],
        update_base_input: bool,
        id: Optional[str] = None,
    ) -> "GenieResult":
        """Execute the code and wrap the outcome in a GenieResult.

        Args:
            code: the code segment to execute.
            additional_inputs: extra keyword inputs passed to the function along with the base data.
            update_base_input: if True, the base data is replaced by the result of the execution.
            id: the id to report in the result. a new id is generated from the code if not provided.

        Raises:
            ValueError: if update_base_input is True and the code returned None.
        """
        result = self.run(code, additional_inputs)
        if update_base_input:
            if result is None:
                raise ValueError("result of genie is None, cannot update base input")
            self.data = result
        id = id or self._generate_id(code)
        return GenieResult(id=id, code=code, cache_dir=self._cache.cache_dir, result=result)

    def custom(
        self, code: str, additional_inputs: Optional[Dict[str, Any]] = None, update_base_input: bool = False
    ) -> "GenieResult":
        """Define a custom genie with user defined code segment. The first argument of the function should be the
        base input of the genie.

        Note that this code should define a stand alone function, ie, it should not depend on any external variables
        or functions or imports. If any additional packages are required, you need to import them in the code segment
        itself.

        Args:
            code: the code segment defining a single function to be used to process data.
            additional_inputs: a dictionary of inputs to the function. the keys are the names of the inputs and
                the values are small description of the inputs.
            update_base_input: if True, the base data will be replaced by the result of executing the code. this
                is used if we are making a permanent update to the input and want to use the updated input moving
                forward.

        Returns:
            A GenieResult instance which contains attributes:
                - result: the result of executing the code
                - id: the id of the genie
                - code: the code generated by the genie
                - cache_dir: the directory where the code is cached. the code will be cached in a file named
                  "cache_dir/<id>.py"
        """
        # proxy instructions as the code entered
        instructions = [code]
        cache_key = self._get_hash_str(instructions, additional_inputs)
        genie_id = self._generate_id(code)
        self._update_cache(
            code, cache_key, instructions, inputs=self._combine_inputs(additional_inputs), id=genie_id
        )
        # pass the id on so that the returned result refers to the entry that was just cached
        return self._get_result(code, additional_inputs, update_base_input, id=genie_id)

    def run(self, code: str, additional_inputs: Optional[Dict[str, Any]] = None):
        """Execute the function defined in code on the base data and the additional inputs.

        Args:
            code: the code segment defining the function to execute.
            additional_inputs: extra keyword inputs passed to the function along with the base data.

        Returns:
            Whatever the function defined in the code returns.

        Raises:
            RuntimeError: if no function name can be extracted from the code or if the function raises.
        """
        executor = self._extract_executable(code)
        try:
            return executor(**self._combine_inputs(additional_inputs, copy_base_input=True))
        except Exception as e:
            raise RuntimeError(f"Failed to execute code segment: \n\n{code}") from e

    @classmethod
    def _extract_fn_name(cls, code: str) -> str:
        """Return the name of the first function defined in the code block.

        Raises:
            RuntimeError: if the code block contains no function definition.
        """
        match = cls._fn_name_pattern.search(code)
        if match is None:
            raise RuntimeError(f"Failed to extract function from code block: {code}")
        return match.group(1)

    def _combine_inputs(
        self, additional_inputs: Optional[Dict[str, Any]], copy_base_input: bool = False
    ) -> Dict[str, Any]:
        """Merge the base data (stored under the base key) with the additional inputs.

        The base data is copied only when both copy_base_input and copy_data_before_use are True.
        """
        data = self.data.copy() if (copy_base_input and self.copy_data_before_use) else self.data
        return {self._base_key: data, **(additional_inputs or {})}

    @staticmethod
    def _create_input_str(x) -> str:
        """Describe an input value as text: the columns for a dataframe, the type for anything else."""
        if isinstance(x, pd.DataFrame):
            return f"pandas dataframes with columns: {x.columns}"
        return f"{type(x)}"

    def _get_code(self, instructions: List[str], inputs: Dict[str, Any]) -> str:
        """Request code from the client for the instructions, describing each input as text."""
        input_str = {key: self._create_input_str(value) for key, value in inputs.items()}
        return self._client.get(instructions=instructions, inputs=input_str)

    @classmethod
    def _extract_executable(cls, code: str) -> Callable:
        """Define the function from the code block in memory and return it."""
        fn_name = cls._extract_fn_name(code)
        namespace: Dict[str, Any] = {}
        # NOTE(security): exec runs the code segment with full interpreter privileges. the code comes
        # from the api, the cache or the user, so it must only be executed in a trusted environment.
        exec(code, namespace)
        return namespace[fn_name]

    @classmethod
    def _generate_id(cls, code: str) -> str:
        """Build an id from the function name in the code and a random 5 digit suffix."""
        fn_name = cls._extract_fn_name(code)
        return f"{fn_name}_{random.randint(10000, 99999)}"

    @classmethod
    def _list_to_str(cls, l: List[str]) -> str:
        """Join the strings with the hash separator."""
        return cls._hash_sep.join(l)

    @classmethod
    def _inputs_to_str(cls, d: Dict[str, Any]) -> str:
        """Serialize the input names and the types of their values, ordered by name."""
        return cls._hash_sep.join(f"{k}={type(d[k])}" for k in sorted(d))

    def _get_hash_str(self, instructions: List[str], additional_inputs: Optional[Dict[str, Any]]) -> str:
        """Build the cache key from the instructions and the names and types of all inputs."""
        hash_strings = [
            self._list_to_str(instructions),
            self._inputs_to_str(self._combine_inputs(additional_inputs)),
        ]
        return self._hash_sep.join(hash_strings)

    def read_cache(self) -> Dict[str, str]:
        """Read all the code segments in the cache directory set in the current genie instance

        Returns:
            A dictionary with keys as the genie ids and values as the code segments
        """
        return self._cache.get_all_code_segments()

    @staticmethod
    def _get_data_key(data) -> str:
        """Return the keyword under which the base data is passed to the code: "df" for dataframes, else "data"."""
        if isinstance(data, pd.DataFrame):
            return "df"
        return "data"