56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
|
from contextlib import suppress
|
||
|
|
||
|
from docutils import nodes
|
||
|
from docutils.parsers.rst import Directive
|
||
|
|
||
|
from sklearn.utils import all_estimators
|
||
|
from sklearn.utils._testing import SkipTest
|
||
|
from sklearn.utils.estimator_checks import _construct_instance
|
||
|
|
||
|
|
||
|
class AllowNanEstimators(Directive):
    """Directive rendering, per estimator type, the estimators tagged ``allow_nan``.

    The output is a bullet list with one entry for each estimator type that
    has at least one matching estimator; each entry holds a nested bullet
    list of links to the matching estimator classes.
    """

    @staticmethod
    def make_paragraph_for_estimator_type(estimator_type):
        """Build the list item for the NaN-tolerant estimators of one type.

        Parameters
        ----------
        estimator_type : str
            Value forwarded to ``all_estimators(type_filter=...)``; it is also
            rendered literally in the entry heading.

        Returns
        -------
        list of docutils.nodes.list_item or None
            A one-element list holding the populated list item, or ``None``
            when no estimator of this type reports the ``allow_nan`` tag.
        """
        intro = nodes.list_item()
        intro += nodes.strong(text="Estimators that allow NaN values for type ")
        intro += nodes.literal(text=f"{estimator_type}")
        intro += nodes.strong(text=":\n")
        exists = False
        lst = nodes.bullet_list()
        for _name, est_class in all_estimators(type_filter=estimator_type):
            # Reset on every iteration: when construction is skipped, ``est``
            # must be neither unbound (NameError on the first estimator) nor
            # left pointing at the previously constructed estimator, which
            # would attribute that estimator's tags to ``est_class``.
            est = None
            with suppress(SkipTest):
                est = _construct_instance(est_class)
            if est is None or not est._get_tags().get("allow_nan"):
                continue

            # Keep only the first two components of the module path for the
            # generated page name (e.g. "pkg.sub" out of "pkg.sub._private").
            module_name = ".".join(est_class.__module__.split(".")[:2])
            class_title = est_class.__name__
            class_url = f"./generated/{module_name}.{class_title}.html"
            item = nodes.list_item()
            para = nodes.paragraph()
            para += nodes.reference(
                class_title, text=class_title, internal=False, refuri=class_url
            )
            exists = True
            item += para
            lst += item
        intro += lst
        return [intro] if exists else None

    def run(self):
        """Return a one-element list holding the bullet list of all entries.

        Estimator types for which ``make_paragraph_for_estimator_type``
        returns ``None`` are omitted from the list.
        """
        lst = nodes.bullet_list()
        for i in ["cluster", "regressor", "classifier", "transformer"]:
            item = self.make_paragraph_for_estimator_type(i)
            if item is not None:
                lst += item
        return [lst]
|
||
|
|
||
|
|
||
|
def setup(app):
    """Register the ``allow_nan_estimators`` directive on ``app``.

    Parameters
    ----------
    app : object
        Application object exposing ``add_directive`` (presumably the Sphinx
        application that loads this module as an extension).

    Returns
    -------
    dict
        Extension metadata: the version string and the flags declaring the
        extension safe for parallel reading and writing.
    """
    extension_metadata = dict(
        version="0.1",
        parallel_read_safe=True,
        parallel_write_safe=True,
    )
    app.add_directive("allow_nan_estimators", AllowNanEstimators)
    return extension_metadata
|