summaryrefslogtreecommitdiff
path: root/genderbench/scripts/estimating_num_repetitions.py
diff options
context:
space:
mode:
Diffstat (limited to 'genderbench/scripts/estimating_num_repetitions.py')
-rw-r--r--genderbench/scripts/estimating_num_repetitions.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/genderbench/scripts/estimating_num_repetitions.py b/genderbench/scripts/estimating_num_repetitions.py
new file mode 100644
index 0000000..9a233bc
--- /dev/null
+++ b/genderbench/scripts/estimating_num_repetitions.py
@@ -0,0 +1,62 @@
+"""
+This script shows an example of how it is possible to increase `num_repetitions`
+one by one. This can be used to make judgments about the optimal value of
+repetitions for individual probes, i.e., you can obsereve the range of the CI
+interval `mx-mn` and see how it changes when the number of repetitions is
+increased.
+
+It is not recommended to use `RandomGenerator` for this use case as it does not
+simulate real LM distribution.
+"""
+
+from genderbench.generators.random import RandomGenerator
+from genderbench.probes.gest.gest_probe import GestProbe
+from genderbench.probing.probe import status
+
+
+def probe_factory():
+ return GestProbe(
+ template=GestProbe.templates[0],
+ num_reorderings=1,
+ calculate_cis=True,
+ )
+
+
+generator = RandomGenerator(["(a)", "(b)", "(c)"])
+
+metric_of_interest = "stereotype_rate"
+
+main_probe = probe_factory()
+
+assert main_probe.calculate_cis
+
+main_probe.run(generator)
+mn, mx = main_probe.metrics[metric_of_interest]
+print("Reps=1", mn, mx, mx - mn)
+
+for i in range(10):
+
+ new_probe = probe_factory()
+ new_probe.calculate_cis = False
+ new_probe.run(generator)
+ for main_item, new_item in zip(main_probe.probe_items, new_probe.probe_items):
+ for attempt in new_item.attempts:
+ attempt.repetition_id = i + 1
+ main_item.attempts.append(attempt)
+
+ del new_probe
+
+ main_item.num_repetitions += 1
+ main_probe.status = status.EVALUATED
+
+ # Clear cache in case metric calculator uses it
+ obj = main_probe.metric_calculator
+ for attr_name in dir(obj):
+ attr = getattr(obj, attr_name)
+ if callable(attr) and hasattr(attr, "cache_clear"):
+ attr.cache_clear()
+
+ main_probe.calculate_metrics()
+
+ mn, mx = main_probe.metrics[metric_of_interest]
+ print(f"Reps={i + 2}", mn, mx, mx - mn)