import random
import re
from typing import Any, Callable, Dict, List, Optional, Union
import pandas as pd
from pydantic import BaseModel
from code_genie._cache import _CacheManager, _CacheValue
from code_genie.client import Client
class GenieResult(BaseModel):
    """The result of a genie execution.

    Instances are immutable (``Config.frozen``), and ``result`` is configured to be
    left out of exports so that only the identifying metadata and the code are
    serialized.
    """

    id: str
    """ID of the genie, this would also be the filename used for storing the generated code in the cache."""
    code: str
    """The code generated by the genie"""
    cache_dir: str
    """The cache directory used by the genie"""
    result: Any = None
    """The result of the execution; None if no result was returned"""

    class Config:
        # always exclude result from json export
        fields = {"result": {"exclude": True}}
        # instances cannot be modified after creation
        frozen = True
class Genie:
    """Generate, cache and execute code segments against a base dataset.

    Code is obtained from the client using text instructions (or supplied directly
    through :meth:`custom`), stored in the cache, and then executed with the base
    data plus any additional inputs passed as keyword arguments.
    """

    # separator used when joining the parts of a cache key
    _hash_sep = "::"

    def __init__(
        self,
        data: Optional[Any] = None,
        client: Optional[Client] = None,
        cache_dir: Optional[str] = None,
        copy_data_before_use: bool = True,
    ):
        """Initialize a genie instance

        Args:
            data: a base dataset whose attributes will be used to generate the code. the result will be determined
                by running this data over the code
            client: an instance of the client to use for making requests to the api. if not provided, a new instance
                will be created.
            cache_dir: if provided, the code generated by the genie will be cached in this directory. if not
                provided, the global default is used. it is recommended to use set_cache_dir() method to set this.
            copy_data_before_use: if True, the data will be copied before passing through generated code. this is to
                prevent the data from being modified inplace by the code. the data passed should have a copy() method
                implemented. if False, the data will be passed as is. this is faster but can lead to unexpected results

        Raises:
            ValueError: if copy_data_before_use is True and data has no ``copy`` attribute.
        """
        self.data = data
        self._base_key = self._get_data_key(data)
        self._cache = _CacheManager(cache_dir)
        self.copy_data_before_use = copy_data_before_use
        # copying is done through data.copy(), so fail early if that is not available
        if copy_data_before_use and not hasattr(data, "copy"):
            raise ValueError(
                "data should have a copy method implemented if copy_data_before_use is True. "
                "Set it to False if you want to continue using the genie"
            )
        self._client = client or Client()

    def plz(
        self,
        instructions: Optional[Union[str, List[str]]],
        additional_inputs: Optional[Dict[str, Any]] = None,
        override: bool = False,
        update_base_input: bool = False,
    ) -> GenieResult:
        """Generate code for a new task

        Args:
            instructions: text instructions on the task required to be performed. use the keywords in inputs argument
                to refer to the inputs.
            additional_inputs: a dictionary of inputs to the function. the keys are the names of the inputs and the
                values are small description of the inputs.
            override: if a genie has been generated before with the same args, then it will be loaded from cache by
                default. set override to True to make a new API call and recreate the genie.
            update_base_input: if True, the base data will be replaced by the result of executing the code. this is used
                if we are making a permanent update to the input and want to use the updated input moving forward.

        Returns:
            A GenieResult instance which contains attributes:
                - result: the result of executing the code
                - id: the id of the genie
                - code: the code generated by the genie
                - cache_dir: the directory where the code is cached. the code will be cached in a file named
                  "cache_dir/<id>.py"
        """
        if isinstance(instructions, str):
            instructions = [instructions]

        # check cache
        cache_key = self._get_hash_str(instructions, additional_inputs)
        cache_value = self._cache.get(cache_key)

        if (not override) and (cache_value is not None):
            # case: reading from cache
            code, genie_id = cache_value.code, cache_value.id
            print(f"Loading cached genie id: {genie_id}, set override = True to rerun")
        else:
            # case: creating new genie
            inputs = self._combine_inputs(additional_inputs)
            code = self._get_code(instructions, inputs)
            genie_id = self._generate_id(code)
            self._update_cache(code, cache_key, instructions, inputs, genie_id)
        return self._get_result(code, additional_inputs, update_base_input, genie_id)

    def _update_cache(
        self, code: str, cache_key: str, instructions: Optional[Union[str, List[str]]], inputs: Dict[str, Any], id: str
    ):
        """Store the code under ``cache_key``; only the names of the inputs are recorded, not their values."""
        self._cache.update(
            cache_key,
            _CacheValue(code=code, id=id, instructions=instructions, inputs=list(inputs.keys())),
        )
        print(f"Genie cached with id: {id}")

    def _get_result(self, code, additional_inputs, update_base_input, id: Optional[str] = None):
        """Execute ``code`` and wrap the outcome in a GenieResult.

        Args:
            code: the code segment to execute.
            additional_inputs: extra keyword inputs passed to the executed function.
            update_base_input: if True, replace the base data with the result of the execution.
            id: id to report in the result. if not provided, a new one is generated from the code.

        Raises:
            ValueError: if update_base_input is True and the execution returned None.
        """
        result = self.run(code, additional_inputs)
        if update_base_input:
            if result is None:
                raise ValueError("result of genie is None, cannot update base input")
            self.data = result
        id = id or self._generate_id(code)
        return GenieResult(id=id, code=code, cache_dir=self._cache.cache_dir, result=result)

    def custom(
        self, code: str, additional_inputs: Optional[Dict[str, Any]] = None, update_base_input: bool = False
    ) -> GenieResult:
        """Define a custom genie with user defined code segment. The first argument of the function should be the
        base input of the genie.

        Note that this code should define a stand alone function, ie, it should not depend on any external variables
        or functions or imports. If any additional packages are required, you need to import them in the code segment
        itself.

        Args:
            code: the code segment defining a single function to be used to process data.
            additional_inputs: a dictionary of inputs to the function. the keys are the names of the inputs and the
                values are small description of the inputs.
            update_base_input: if True, the base data will be replaced by the result of executing the code. this is used
                if we are making a permanent update to the input and want to use the updated input moving forward.

        Returns:
            A GenieResult instance which contains attributes:
                - result: the result of executing the code
                - id: the id of the genie
                - code: the code generated by the genie
                - cache_dir: the directory where the code is cached. the code will be cached in a file named
                  "cache_dir/<id>.py"
        """
        # proxy instructions as the code entered
        instructions = [code]
        cache_key = self._get_hash_str(instructions, additional_inputs)
        genie_id = self._generate_id(code)
        self._update_cache(
            code, cache_key, instructions, inputs=self._combine_inputs(additional_inputs), id=genie_id
        )
        # pass the id along so the returned result refers to the same id that was just cached
        return self._get_result(code, additional_inputs, update_base_input, genie_id)

    def run(self, code: str, additional_inputs: Optional[Dict[str, Any]] = None):
        """Execute a code segment on the base data and return whatever the function returns.

        Args:
            code: code segment defining the function to execute.
            additional_inputs: extra keyword arguments passed to the function along with the base data.

        Raises:
            RuntimeError: if no function can be extracted from the code, or if executing it fails.
        """
        executor = self._extract_executable(code)
        try:
            return executor(**self._combine_inputs(additional_inputs, copy_base_input=True))
        except Exception as e:
            raise RuntimeError(f"Failed to execute code segment: \n\n{code}") from e

    @classmethod
    def _extract_fn_name(cls, code: str):
        """Return the name of the first function defined in ``code``.

        Raises:
            RuntimeError: if no function definition is found.
        """
        # capture only the identifier between "def" and the opening parenthesis, so that nested
        # parentheses in default values and signatures spanning multiple lines are handled.
        # [^\W\d] matches a word character that is not a digit, ie, a valid first character.
        match = re.search(r"\bdef\s+([^\W\d]\w*)\s*\(", code)
        if match is None:
            raise RuntimeError(f"Failed to extract function from code block: {code}")
        return match.group(1)

    def _combine_inputs(
        self, additional_inputs: Optional[Dict[str, Any]], copy_base_input: bool = False
    ) -> Dict[str, Any]:
        """Merge the base data (under its key) with the additional inputs.

        The base data is copied only when both ``copy_base_input`` and ``self.copy_data_before_use`` are True.
        """
        data = self.data.copy() if (copy_base_input and self.copy_data_before_use) else self.data
        return {self._base_key: data, **(additional_inputs or {})}

    @staticmethod
    def _create_input_str(x):
        """Describe an input for the client: columns for a dataframe, the type for anything else."""
        if isinstance(x, pd.DataFrame):
            return f"pandas dataframes with columns: {x.columns}"
        return f"{type(x)}"

    def _get_code(self, instructions: List[str], inputs: Dict[str, Any]) -> str:
        """Request code from the client, sending descriptions of the inputs rather than the inputs themselves."""
        input_str = {key: self._create_input_str(value) for key, value in inputs.items()}
        return self._client.get(instructions=instructions, inputs=input_str)

    @classmethod
    def _extract_executable(cls, code: str) -> Callable:
        """Define the function from ``code`` in memory and return it."""
        fn_name = cls._extract_fn_name(code)
        mem: Dict[str, Any] = {}
        # SECURITY: exec runs the code segment with full interpreter privileges. the code comes from
        # the client, the cache or the user, so it should only be used with sources that are trusted.
        exec(code, mem)
        return mem[fn_name]

    @classmethod
    def _generate_id(cls, code: str) -> str:
        """Create an id from the function name in ``code``; ids differ between calls for the same code."""
        fn_name = cls._extract_fn_name(code)
        # use fn name with random 5 digit suffix
        return f"{fn_name}_{random.randint(10000, 99999)}"

    @classmethod
    def _list_to_str(cls, l: List[str]) -> str:
        """Join a list of strings with the hash separator."""
        return cls._hash_sep.join(l)

    @classmethod
    def _inputs_to_str(cls, d: Dict[str, Any]) -> str:
        """Represent inputs as "name=type" pairs in sorted key order, so the order of keys does not matter."""
        sorted_keys = sorted(d.keys())
        return cls._hash_sep.join([f"{k}={type(d[k])}" for k in sorted_keys])

    def _get_hash_str(self, instructions: List[str], additional_inputs: Optional[Dict[str, Any]]) -> str:
        """Build the cache key from the instructions and the names and types of all inputs."""
        hash_strings = [
            self._list_to_str(instructions),
            self._inputs_to_str(self._combine_inputs(additional_inputs)),
        ]
        return self._hash_sep.join(hash_strings)

    def read_cache(self) -> Dict[str, str]:
        """Read all the code segments in the cache directory set in the current genie instance

        Returns:
            A dictionary with keys as the genie ids and values as the code segments
        """
        return self._cache.get_all_code_segments()

    @staticmethod
    def _get_data_key(data):
        """Name under which the base data is passed to the code: "df" for a dataframe, "data" otherwise."""
        if isinstance(data, pd.DataFrame):
            return "df"
        return "data"